diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 00000000..07322779 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,15 @@ +# Environment Code +from .task_interface import TaskWrapper, SubclassTaskWrapper, ReinitTaskWrapper, TaskEnv # , PettingZooTaskWrapper, PettingZooTaskEnv + +# Curriculum Code +from .utils import decorate_all_functions, UsageError, enumerate_axes +from .curriculum_base import Curriculum +from .curriculum_sync_wrapper import (CurriculumWrapper, + MultiProcessingCurriculumWrapper, + MultiProcessingComponents, + RayCurriculumWrapper, + make_multiprocessing_curriculum, + make_ray_curriculum) + +from .environment_sync_wrapper import MultiProcessingSyncWrapper, RaySyncWrapper # , PettingZooMultiProcessingSyncWrapper +from .multivariate_curriculum_wrapper import MultitaskWrapper diff --git a/core/curriculum_base.py b/core/curriculum_base.py new file mode 100644 index 00000000..4ca9aeb0 --- /dev/null +++ b/core/curriculum_base.py @@ -0,0 +1,223 @@ +import typing +import warnings +from typing import Any, Callable, List, Tuple, Union + +import numpy as np +from gymnasium.spaces import Dict + +from syllabus.task_space import TaskSpace + + +# TODO: Move non-generic logic to Uniform class. Allow subclasses to call super for generic error handling +class Curriculum: + """Base class and API for defining curricula to interface with Gym environments. + """ + + def __init__(self, task_space: TaskSpace, random_start_tasks: int = 0, task_names: Callable = None) -> None: + """Initialize the base Curriculum + + :param task_space: the environment's task space from which new tasks are sampled + TODO: Implement this in a way that works with any curriculum, maybe as a wrapper + :param random_start_tasks: Number of uniform random tasks to sample before using the algorithm's sample method, defaults to 0 + TODO: Use task space for this + :param task_names: Names of the tasks in the task space, defaults to None + """ + assert isinstance(task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." + self.task_space = task_space + self.random_start_tasks = random_start_tasks + self.completed_tasks = 0 + self.task_names = task_names + self.n_updates = 0 + + if self.num_tasks == 0: + warnings.warn("Task space is empty. This will cause errors during sampling if no tasks are added.") + + @property + def requires_step_updates(self) -> bool: + """Returns whether the curriculum requires step updates from the environment. + + :return: True if the curriculum requires step updates, False otherwise + """ + return self.__class__.REQUIRES_STEP_UPDATES + + @property + def requires_episode_updates(self) -> bool: + """Returns whether the curriculum requires episode updates from the environment. + + :return: True if the curriculum requires episode updates, False otherwise + """ + return self.__class__.REQUIRES_EPISODE_UPDATES + + @property + def num_tasks(self) -> int: + """Counts the number of tasks in the task space. + + :return: Returns the number of tasks in the task space if it is countable, TODO: -1 otherwise + """ + return self.task_space.num_tasks + + @property + def tasks(self) -> List[tuple]: + """List all of the tasks in the task space. + + :return: List of tasks if task space is enumerable, TODO: empty list otherwise? + """ + return list(self.task_space.tasks) + + def add_task(self, task: typing.Any) -> None: + # TODO + raise NotImplementedError("This curriculum does not support adding tasks after initialization.") + + def update_task_progress(self, task: typing.Any, progress: Tuple[float, bool], env_id: int = None) -> None: + """Update the curriculum with a task and its progress. + + :param task: Task for which progress is being updated. + :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task. + """ + self.completed_tasks += 1 + + def update_on_step(self, task: typing.Any, obs: typing.Any, rew: float, term: bool, trunc: bool, info: dict, env_id: int = None) -> None: + """ Update the curriculum with the current step results from the environment. + + :param obs: Observation from teh environment + :param rew: Reward from the environment + :param term: True if the episode ended on this step, False otherwise + :param trunc: True if the episode was truncated on this step, False otherwise + :param info: Extra information from the environment + :raises NotImplementedError: + """ + raise NotImplementedError("This curriculum does not require step updates. Set update_on_step for the environment sync wrapper to False to improve performance and prevent this error.") + + def update_on_step_batch(self, step_results: List[typing.Tuple[Any, Any, int, int, int, int]], env_id: int = None) -> None: + """Update the curriculum with a batch of step results from the environment. + + This method can be overridden to provide a more efficient implementation. It is used + as a convenience function and to optimize the multiprocessing message passing throughput. + + :param step_results: List of step results + """ + tasks, obs, rews, terms, truncs, infos = tuple(step_results) + for i in range(len(obs)): + self.update_on_step(tasks[i], obs[i], rews[i], terms[i], truncs[i], infos[i], env_id=env_id) + + def update_on_episode(self, episode_return: float, episode_length: int, episode_task: Any, env_id: int = None) -> None: + """Update the curriculum with episode results from the environment. + + :param episode_return: Episodic return + :param trajectory: trajectory of (s, a, r, s, ...), defaults to None + :raises NotImplementedError: + """ + # TODO: Add update_on_episode option similar to update-on_step + pass + + def update_on_demand(self, metrics: Dict): + """Update the curriculum with arbitrary inputs. + + + :param metrics: Arbitrary dictionary of information. Can be used to provide gradient/error based + updates from the training process. + :raises NotImplementedError: + """ + raise NotImplementedError + + # TODO: Move to curriculum sync wrapper? + def update(self, update_data: typing.Dict[str, tuple]): + """Update the curriculum with the specified update type. + TODO: Change method header to not use dictionary, use enums? + + :param update_data: Dictionary + :type update_data: Dictionary with "update_type" key which maps to one of ["step", "step_batch", "episode", "on_demand", "task_progress", "add_task", "noop"] and "args" with a tuple of the appropriate arguments for the given "update_type". + :raises NotImplementedError: + """ + + update_type = update_data["update_type"] + args = update_data["metrics"] + env_id = update_data["env_id"] if "env_id" in update_data else None + + if update_type == "step": + self.update_on_step(*args, env_id=env_id) + elif update_type == "step_batch": + self.update_on_step_batch(*args, env_id=env_id) + elif update_type == "episode": + self.update_on_episode(*args, env_id=env_id) + elif update_type == "on_demand": + # Directly pass metrics without expanding + self.update_on_demand(args) + elif update_type == "task_progress": + self.update_task_progress(*args, env_id=env_id) + elif update_type == "task_progress_batch": + tasks, progresses = args + for task, progress in zip(tasks, progresses): + self.update_task_progress(task, progress, env_id=env_id) + elif update_type == "add_task": + self.add_task(args) + elif update_type == "noop": + # Used to request tasks from the synchronization layer + pass + else: + raise NotImplementedError(f"Update type {update_type} not implemented.") + self.n_updates += 1 + + def update_batch(self, update_data: List[Dict]): + """Update the curriculum with batch of updates. + + :param update_data: List of updates or potentially varying types + """ + for update in update_data: + self.update(update) + + def _sample_distribution(self) -> List[float]: + """Returns a sample distribution over the task space. + + Any curriculum that maintains a true probability distribution should implement this method to retrieve it. + """ + raise NotImplementedError + + def _should_use_startup_sampling(self) -> bool: + return self.random_start_tasks > 0 and self.completed_tasks < self.random_start_tasks + + def _startup_sample(self) -> List: + task_dist = [0.0 / self.num_tasks for _ in range(self.num_tasks)] + task_dist[0] = 1.0 + return task_dist + + def sample(self, k: int = 1) -> Union[List, Any]: + """Sample k tasks from the curriculum. + + :param k: Number of tasks to sample, defaults to 1 + :return: Either returns a single task if k=1, or a list of k tasks + """ + assert self.num_tasks > 0, "Task space is empty. Please add tasks to the curriculum before sampling." + + if self._should_use_startup_sampling(): + return self._startup_sample() + + # Use list of indices because np.choice does not play nice with tuple tasks + # tasks = self.tasks + n_tasks = self.num_tasks + task_dist = self._sample_distribution() + task_idx = np.random.choice(list(range(n_tasks)), size=k, p=task_dist) + return task_idx + + def log_metrics(self, writer, step=None, log_full_dist=False): + """Log the task distribution to the provided tensorboard writer. + + :param writer: Tensorboard summary writer. + """ + try: + import wandb + task_dist = self._sample_distribution() + if len(task_dist) > 10 and not log_full_dist: + warnings.warn("Only logging stats for 10 tasks.") + task_dist = task_dist[:10] + if self.task_names: + for idx, prob in enumerate(task_dist): + writer.add_scalar(f"curriculum/task_{self.task_space.task_name(idx)}_prob", prob, step) + else: + for idx, prob in enumerate(task_dist): + writer.add_scalar(f"curriculum/task_{idx}_prob", prob, step) + except ImportError: + warnings.warn("Wandb is not installed. Skipping logging.") + except wandb.errors.Error: + # No need to crash over logging :) + warnings.warn("Failed to log curriculum stats to wandb.") diff --git a/core/curriculum_sync_wrapper.py b/core/curriculum_sync_wrapper.py new file mode 100644 index 00000000..f9866438 --- /dev/null +++ b/core/curriculum_sync_wrapper.py @@ -0,0 +1,323 @@ +import threading +import time +from functools import wraps +from multiprocessing.shared_memory import ShareableList +from typing import List, Tuple + +import ray +from torch.multiprocessing import Lock, SimpleQueue + +from syllabus.core import Curriculum, decorate_all_functions + + +class CurriculumWrapper: + """Wrapper class for adding multiprocessing synchronization to a curriculum. + """ + def __init__(self, curriculum: Curriculum) -> None: + self.curriculum = curriculum + self.task_space = curriculum.task_space + self.unwrapped = curriculum + + @property + def num_tasks(self): + return self.task_space.num_tasks + + def count_tasks(self, task_space=None): + return self.task_space.count_tasks(gym_space=task_space) + + @property + def tasks(self): + return self.task_space.tasks + + @property + def requires_step_updates(self): + return self.curriculum.requires_step_updates + + @property + def requires_episode_updates(self): + return self.curriculum.requires_episode_updates + + def get_tasks(self, task_space=None): + return self.task_space.get_tasks(gym_space=task_space) + + def sample(self, k=1): + return self.curriculum.sample(k=k) + + def update_task_progress(self, task, progress): + self.curriculum.update_task_progress(task, progress) + + def update_on_step(self, task, step, reward, term, trunc): + self.curriculum.update_on_step(task, step, reward, term, trunc) + + def log_metrics(self, writer, step=None): + self.curriculum.log_metrics(writer, step=step) + + def update_on_step_batch(self, step_results): + self.curriculum.update_on_step_batch(step_results) + + def update(self, metrics): + self.curriculum.update(metrics) + + def update_batch(self, metrics): + self.curriculum.update_batch(metrics) + + def add_task(self, task): + self.curriculum.add_task(task) + + +class MultiProcessingComponents: + def __init__(self, task_queue, update_queue): + self.task_queue = task_queue + self.update_queue = update_queue + self._instance_lock = Lock() + self._env_count = ShareableList([0]) + self._task_count = ShareableList([0]) + self._update_count = ShareableList([0]) + self._debug = False + self._verbose = False + + def get_id(self): + with self._instance_lock: + instance_id = self._env_count[0] + self._env_count[0] += 1 + return instance_id + + def put_task(self, task): + self.task_queue.put(task) + if self._debug: + task_count = self.added_task() + if self._verbose: + print(f"Task added to queue. Task count: {task_count}") + + def get_task(self): + task = self.task_queue.get() + if self._debug: + task_count = self.removed_task() + if self._verbose: + print(f"Task removed from queue. Task count: {task_count}") + return task + + def put_update(self, update): + self.update_queue.put(update) + if self._debug: + update_count = self.added_update() + if self._verbose: + print(f"Update added to queue. Update count: {update_count}") + + def get_update(self): + update = self.update_queue.get() + if self._debug: + update_count = self.removed_update() + if self._verbose: + print(f"Update removed from queue. Update count: {update_count}") + + return update + + def added_task(self): + with self._instance_lock: + self._task_count[0] += 1 + task_count = self._task_count[0] + return task_count + + def removed_task(self): + with self._instance_lock: + self._task_count[0] -= 1 + task_count = self._task_count[0] + return task_count + + def get_task_count(self): + with self._instance_lock: + task_count = self._task_count[0] + return task_count + + def get_update_count(self): + with self._instance_lock: + update_count = self._update_count[0] + return update_count + + def added_update(self): + with self._instance_lock: + self._update_count[0] += 1 + update_count = self._update_count[0] + return update_count + + def removed_update(self): + with self._instance_lock: + self._update_count[0] -= 1 + update_count = self._update_count[0] + return update_count + + +class MultiProcessingCurriculumWrapper(CurriculumWrapper): + def __init__( + self, + curriculum: Curriculum, + task_queue: SimpleQueue, + update_queue: SimpleQueue, + sequential_start: bool = True + ): + super().__init__(curriculum) + self.task_queue = task_queue + self.update_queue = update_queue + self.sequential_start = sequential_start + + self.update_thread = None + self.should_update = False + self.added_tasks = [] + self.num_assigned_tasks = 0 + + self._components = MultiProcessingComponents(task_queue, update_queue) + + def start(self): + """ + Start the thread that reads the complete_queue and reads the task_queue. + """ + self.update_thread = threading.Thread(name='update', target=self._update_queues, daemon=True) + self.should_update = True + self.update_thread.start() + + def stop(self): + """ + Stop the thread that reads the complete_queue and reads the task_queue. + """ + # Process final few updates + start = time.time() + end = time.time() + while end - start < 3 and not self.update_queue.empty(): + time.sleep(0.5) + end = time.time() + + self.should_update = False + components = self.get_components() + components._env_count.shm.close() + components._env_count.shm.unlink() + components._update_count.shm.close() + components._update_count.shm.unlink() + components._task_count.shm.close() + components._task_count.shm.unlink() + # components.task_queue.close() + # components.update_queue.close() + + def _update_queues(self): + """ + Continuously process completed tasks and sample new tasks. + """ + # TODO: Refactor long method? Write tests first + # Update curriculum with environment results: + while self.should_update: + requested_tasks = 0 + while not self.update_queue.empty(): + batch_updates = self.get_components().get_update() # Blocks until update is available + + if isinstance(batch_updates, dict): + batch_updates = [batch_updates] + + # Count number of requested tasks + for update in batch_updates: + if "request_sample" in update and update["request_sample"]: + requested_tasks += 1 + + self.update_batch(batch_updates) + + # Sample new tasks + if requested_tasks > 0: + new_tasks = self.curriculum.sample(k=requested_tasks) + for i, task in enumerate(new_tasks): + message = { + "next_task": task, + "sample_id": self.num_assigned_tasks + i, + } + + self.get_components().put_task(message) + self.num_assigned_tasks += requested_tasks + time.sleep(0) + else: + time.sleep(0.01) + + def log_metrics(self, writer, step=None): + super().log_metrics(writer, step=step) + if self.get_components()._debug: + writer.add_scalar("curriculum/updates_in_queue", self.get_components()._update_count[0], step) + writer.add_scalar("curriculum/tasks_in_queue", self.get_components()._task_count[0], step) + + def add_task(self, task): + super().add_task(task) + self.added_tasks.append(task) + + def get_components(self): + return self._components + + +def remote_call(func): + """ + Decorator for automatically forwarding calls to the curriculum via ray remote calls. + + Note that this causes functions to block, and should be only used for operations that do not require parallelization. + """ + @wraps(func) + def wrapper(self, *args, **kw): + f_name = func.__name__ + parent_func = getattr(CurriculumWrapper, f_name) + child_func = getattr(self, f_name) + + # Only forward call if subclass does not explicitly override the function. + if child_func == parent_func: + curriculum_func = getattr(self.curriculum, f_name) + return ray.get(curriculum_func.remote(*args, **kw)) + return wrapper + + +def make_multiprocessing_curriculum(curriculum, **kwargs): + """ + Helper function for creating a MultiProcessingCurriculumWrapper. + """ + task_queue = SimpleQueue() + update_queue = SimpleQueue() + + mp_curriculum = MultiProcessingCurriculumWrapper(curriculum, task_queue, update_queue, **kwargs) + mp_curriculum.start() + return mp_curriculum + + +@ray.remote +class RayWrapper(CurriculumWrapper): + def __init__(self, curriculum: Curriculum) -> None: + super().__init__(curriculum) + + +@decorate_all_functions(remote_call) +class RayCurriculumWrapper(CurriculumWrapper): + """ + Subclass of LearningProgress Curriculum that uses Ray to share tasks and receive feedback + from the environment. The only change is the @ray.remote decorator on the class. + + The @decorate_all_functions(remote_call) annotation automatically forwards all functions not explicitly + overridden here to the remote curriculum. This is intended to forward private functions of Curriculum subclasses + for convenience. + # TODO: Implement the Curriculum methods explicitly + """ + def __init__(self, curriculum, actor_name="curriculum") -> None: + super().__init__(curriculum) + self.curriculum = RayWrapper.options(name=actor_name).remote(curriculum) + self.unwrapped = None + self.task_space = curriculum.task_space + self.added_tasks = [] + + # If you choose to override a function, you will need to forward the call to the remote curriculum. + # This method is shown here as an example. If you remove it, the same functionality will be provided automatically. + def sample(self, k: int = 1): + return ray.get(self.curriculum.sample.remote(k=k)) + + def update_on_step_batch(self, step_results: List[Tuple[int, int, int, int]]) -> None: + ray.get(self.curriculum._on_step_batch.remote(step_results)) + + def add_task(self, task): + super().add_task(task) + self.added_tasks.append(task) + + +def make_ray_curriculum(curriculum, actor_name="curriculum", **kwargs): + """ + Helper function for creating a RayCurriculumWrapper. + """ + return RayCurriculumWrapper(curriculum, actor_name=actor_name, **kwargs) diff --git a/core/environment_sync_wrapper.py b/core/environment_sync_wrapper.py new file mode 100644 index 00000000..6edee7cc --- /dev/null +++ b/core/environment_sync_wrapper.py @@ -0,0 +1,348 @@ +from typing import Any, Callable, Dict + +import gymnasium as gym +import numpy as np +import ray +from gymnasium.utils.step_api_compatibility import step_api_compatibility + +from syllabus.core import Curriculum, MultiProcessingCurriculumWrapper, MultiProcessingComponents, TaskEnv, TaskWrapper +from syllabus.task_space import TaskSpace + + +class MultiProcessingSyncWrapper(gym.Wrapper): + """ + This wrapper is used to set the task on reset for a Gym environments running + on parallel processes created using multiprocessing.Process. Meant to be used + with a QueueLearningProgressCurriculum running on the main process. + """ + + def __init__(self, + env, + components: MultiProcessingComponents, + update_on_step: bool = False, # TODO: Fine grained control over which step elements are used. Controlled by curriculum? + update_on_progress: bool = False, # TODO: Fine grained control over which step elements are used. Controlled by curriculum? + batch_size: int = 100, + buffer_size: int = 2, # Having an extra task in the buffer minimizes wait time at reset + task_space: TaskSpace = None, + global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None): + # TODO: reimplement global task progress metrics + assert isinstance(task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." + super().__init__(env) + self.env = env + self.components = components + self._latest_task = None + self.task_queue = components.task_queue + self.update_queue = components.update_queue + self.task_space = task_space + self.update_on_step = update_on_step + self.update_on_progress = update_on_progress + self.batch_size = batch_size + self.global_task_completion = global_task_completion + self.task_progress = 0.0 + self._batch_step = 0 + self.instance_id = components.get_id() + + self.episode_length = 0 + self.episode_return = 0 + + # Create batch buffers for step updates + if self.update_on_step: + self._obs = [None] * self.batch_size + self._rews = np.zeros(self.batch_size, dtype=np.float32) + self._terms = np.zeros(self.batch_size, dtype=bool) + self._truncs = np.zeros(self.batch_size, dtype=bool) + self._infos = [None] * self.batch_size + self._tasks = [None] * self.batch_size + self._task_progresses = [None] * self.batch_size + + # Request initial task + assert buffer_size > 0, "Buffer size must be greater than 0 to sample initial task for envs." + for _ in range(buffer_size): + update = { + "update_type": "noop", + "metrics": None, + "request_sample": True, + } + self.components.put_update(update) + + def reset(self, *args, **kwargs): + self.step_updates = [] + self.task_progress = 0.0 + self.episode_length = 0 + self.episode_return = 0 + + message = self.components.get_task() # Blocks until a task is available + next_task = self.task_space.decode(message["next_task"]) + self._latest_task = next_task + + # Add any new tasks + if "added_tasks" in message: + added_tasks = message["added_tasks"] + for add_task in added_tasks: + self.env.add_task(add_task) + return self.env.reset(*args, new_task=next_task, **kwargs) + + def step(self, action): + obs, rew, term, trunc, info = step_api_compatibility(self.env.step(action), output_truncation_bool=True) + self.episode_length += 1 + self.episode_return += rew + self.task_progress = info.get("task_completion", 0.0) + + # Update curriculum with step info + if self.update_on_step: + self._obs[self._batch_step] = obs + self._rews[self._batch_step] = rew + self._terms[self._batch_step] = term + self._truncs[self._batch_step] = trunc + self._infos[self._batch_step] = info + self._tasks[self._batch_step] = self.task_space.encode(self.get_task()) + self._task_progresses[self._batch_step] = self.task_progress + self._batch_step += 1 + + # Send batched updates + if self._batch_step >= self.batch_size or term or trunc: + updates = self._package_step_updates() + self.components.put_update(updates) + self._batch_step = 0 + + # Episode update + if term or trunc: + # Task progress + task_update = { + "update_type": "task_progress", + "metrics": ((self.task_space.encode(self.env.task), self.task_progress)), + "env_id": self.instance_id, + "request_sample": False, + } + episode_update = { + "update_type": "episode", + "metrics": (self.episode_return, self.episode_length, self.task_space.encode(self.env.task)), + "env_id": self.instance_id, + "request_sample": True + } + self.components.put_update([task_update, episode_update]) + + return obs, rew, term, trunc, info + + def _package_step_updates(self): + step_batch = { + "update_type": "step_batch", + "metrics": ([self._tasks[:self._batch_step], self._obs[:self._batch_step], self._rews[:self._batch_step], self._terms[:self._batch_step], self._truncs[:self._batch_step], self._infos[:self._batch_step]],), + "env_id": self.instance_id, + "request_sample": False + } + update = [step_batch] + + if self.update_on_progress: + task_batch = { + "update_type": "task_progress_batch", + "metrics": (self._tasks[:self._batch_step], self._task_progresses[:self._batch_step],), + "env_id": self.instance_id, + "request_sample": False + } + update.append(task_batch) + return update + + def add_task(self, task): + update = { + "update_type": "add_task", + "metrics": task + } + self.update_queue.put(update) + + def get_task(self): + # Allow user to reject task + if hasattr(self.env, "task"): + return self.env.task + return self._latest_task + + def __getattr__(self, attr): + env_attr = getattr(self.env, attr, None) + if env_attr is not None: + return env_attr + + +# TODO: Fix this and refactor +# class PettingZooMultiProcessingSyncWrapper(BaseParallelWraper): +# """ +# This wrapper is used to set the task on reset for a Gym environments running +# on parallel processes created using multiprocessing.Process. Meant to be used +# with a QueueLearningProgressCurriculum running on the main process. +# """ +# def __init__(self, +# env, +# task_queue: SimpleQueue, +# update_queue: SimpleQueue, +# update_on_step: bool = True, # TODO: Fine grained control over which step elements are used. Controlled by curriculum? +# default_task=None, +# task_space: TaskSpace = None, +# global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None): +# super().__init__(env) +# self.env = env +# self.task_queue = task_queue +# self.update_queue = update_queue +# self.task_space = task_space +# self.update_on_step = update_on_step +# self.global_task_completion = global_task_completion +# self.task_completion = 0.0 +# self.warned_once = False +# self.step_results = [] +# if task_space.contains(default_task): +# self.default_task = default_task + +# # Request initial task +# update = { +# "update_type": "noop", +# "metrics": None, +# "request_sample": True +# } +# self.update_queue.put(update) + +# @property +# def agents(self): +# return self.env.agents + +# def reset(self, *args, **kwargs): +# self.step_results = [] + +# # Update curriculum +# update = { +# "update_type": "complete", +# "metrics": (self.task_space.encode(self.env.task), self.task_completion), +# "request_sample": True +# } +# self.update_queue.put(update) +# self.task_completion = 0.0 + +# # Sample new task +# if self.task_queue.empty(): +# # Choose default task if it is set, or keep the current task +# next_task = self.default_task if self.default_task is not None else self.task_space.sample() +# if not self.warned_once: +# print("\nTask queue was empty, selecting default task. This warning will not print again for this environment.\n") +# self.warned_once = False +# else: +# message = self.task_queue.get() +# next_task = self.task_space.decode(message["next_task"]) +# if "add_task" in message: +# self.env.add_task(message["add_task"]) +# return self.env.reset(*args, new_task=next_task, **kwargs) + +# def step(self, action): +# obs, rew, term, trunc, info = self.env.step(action) + +# if "task_completion" in info: +# if self.global_task_completion is not None: +# self.task_completion = self.global_task_completion(self.curriculum, obs, rew, term, trunc, info) +# else: +# self.task_completion = info["task_completion"] + +# if self.update_on_step: +# self.step_results.append((obs, rew, term, trunc, info)) +# if len(self.step_results) >= 2000: +# update = { +# "update_type": "step_batch", +# "metrics": (self.step_results,), +# "request_sample": False +# } +# self.update_queue.put(update) +# self.step_results = [] + +# return obs, rew, term, trunc, info + +# def add_task(self, task): +# update = { +# "update_type": "add_task", +# "metrics": task +# } +# self.update_queue.put(update) + +# def __getattr__(self, attr): +# env_attr = getattr(self.env, attr, None) +# if env_attr: +# return env_attr + + +class RaySyncWrapper(gym.Wrapper): + """ + This wrapper is used to set the task on reset for a Gym environments running + on parallel processes created using ray. Meant to be used with a + RayLearningProgressCurriculum running on the main process. + """ + def __init__(self, + env, + update_on_step: bool = True, + task_space: gym.Space = None, + global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None): + assert isinstance(env, TaskWrapper) or isinstance(env, TaskEnv) or isinstance(env, PettingZooTaskWrapper), "Env must implement the task API" + super().__init__(env) + self.env = env + self.update_on_step = update_on_step # Disable to improve performance 10x + self.task_space = task_space + self.curriculum = ray.get_actor("curriculum") + self.task_completion = 0.0 + self.global_task_completion = global_task_completion + self.step_results = [] + + def reset(self, *args, **kwargs): + self.step_results = [] + + # Update curriculum + update = { + "update_type": "task_progress", + "metrics": (self.env.task, self.task_completion), + "request_sample": True + } + self.curriculum.update.remote(update) + self.task_completion = 0.0 + + # Sample new task + sample = ray.get(self.curriculum.sample.remote()) + next_task = sample[0] + + return self.env.reset(*args, new_task=next_task, **kwargs) + + def step(self, action): + obs, rew, term, trunc, info = self.env.step(action) + + if "task_completion" in info: + if self.global_task_completion is not None: + # TODO: Hide rllib interface? + self.task_completion = self.global_task_completion(self.curriculum, obs, rew, term, trunc, info) + else: + self.task_completion = info["task_completion"] + + # TODO: Optimize + if self.update_on_step: + self.step_results.append((obs, rew, term, trunc, info)) + if len(self.step_results) >= 1000 or term or trunc: + update = { + "update_type": "step_batch", + "metrics": (self.step_results,), + "request_sample": False + } + self.curriculum.update.remote(update) + self.step_results = [] + + return obs, rew, term, trunc, info + + def change_task(self, new_task): + """ + Changes the task of the existing environment to the new_task. + + Each environment will implement tasks differently. The easiest system would be to call a + function or set an instance variable to change the task. + + Some environments may need to be reset or even reinitialized to change the task. + If you need to reset or re-init the environment here, make sure to check + that it is not in the middle of an episode to avoid unexpected behavior. + """ + self.env.change_task(new_task) + + def add_task(self, task): + self.curriculum.add_task.remote(task) + + def __getattr__(self, attr): + env_attr = getattr(self.env, attr, None) + if env_attr: + return env_attr diff --git a/core/multivariate_curriculum_wrapper.py b/core/multivariate_curriculum_wrapper.py new file mode 100644 index 00000000..ca556366 --- /dev/null +++ b/core/multivariate_curriculum_wrapper.py @@ -0,0 +1,82 @@ +import itertools +import typing +from typing import Any, Callable, List, Union + +import numpy as np +from gymnasium.spaces import Dict, Tuple +from syllabus.core import Curriculum, CurriculumWrapper +from syllabus.task_space import TaskSpace + + +class MultitaskWrapper(CurriculumWrapper): + """ + Uniform sampling for task spaces with multiple subspaces (Tuple or Dict) + """ + # TODO: How do I use curriculum wrappers with the make_curriculum functions? + def __init__(self, *args, num_components: int, component_names: List[str] = None, **kwargs): + super().__init__(*args, **kwargs) + self.num_components = num_components + + # Duplicate task space for each component + if num_components is not None: + self.task_space = TaskSpace(Tuple([self.task_space.gym_space for _ in range(num_components)]), (tuple(self.task_space.tasks),) * num_components) + elif component_names is not None: + self.task_space = TaskSpace(Dict({name: self.task_space.gym_space for name in component_names}), {name: self.task_space.tasks for name in component_names}) + + def _sample_distribution(self) -> List[float]: + """ + Returns a sample distribution over the task space. + """ + # Uniform distribution + if isinstance(self.task_space.gym_space, Tuple): + multivariate_dists = [self.curriculum._sample_distribution() for _ in self.task_space.gym_space.spaces] + elif isinstance(self.task_space.gym_space, Dict): + multivariate_dists = {name: self.curriculum._sample_distribution() for name in self.task_space.gym_space.keys()} + else: + raise NotImplementedError("Multivariate task space must be Tuple or Dict.") + return multivariate_dists + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + assert self.num_tasks > 0, "Task space is empty. Please add tasks to the curriculum before sampling." + + tasks = [] + for _ in range(k): + sample_dist = self._sample_distribution() + if isinstance(sample_dist, list): + task_components = [] + for dist in sample_dist: + task_components.append(self.curriculum.sample(k=1)[0]) + tasks.append(tuple(task_components)) + return tasks + + + multitask_dist = self._sample_distribution() + # TODO: Clean and comment + if isinstance(self.task_space.gym_space, Dict): + multitasks = [] + for _ in range(k): + multitask = {} + # TODO: Provide easier access to gym_space properties? + for (space_name, task_space), task_dist in zip(self.task_space.tasks.items(), multitask_dist): + n_tasks = len(task_dist) + task_idx = np.random.choice(list(range(n_tasks)), size=1, p=task_dist) + multitask[space_name] = np.array([self.get_tasks(task_space)[i] for i in task_idx]) + multitasks.append(multitask) + return multitasks + elif isinstance(self.task_space.gym_space, Tuple): + multitask = [] + for tasks, task_dist in zip(self.task_space.tasks, multitask_dist): + print(tasks) + n_tasks = len(task_dist) + task_idx = np.random.choice(list(range(n_tasks)), size=k, p=task_dist) + multitask.append(np.array([tasks[i] for i in task_idx])) + multitask = np.array(multitask) + return np.moveaxis(multitask, -1, 0) + else: + raise NotImplementedError("Multivariate task space must be Tuple or Dict.") + + def log_metrics(self, writer, step=None): + raise NotImplementedError("Multitask curriculum does not support logging metrics.") \ No newline at end of file diff --git a/core/task_interface/__init__.py b/core/task_interface/__init__.py new file mode 100644 index 00000000..d4deb940 --- /dev/null +++ b/core/task_interface/__init__.py @@ -0,0 +1,5 @@ +# Environment Code +from .task_wrapper import TaskWrapper # , PettingZooTaskWrapper +from .subclass_task_wrapper import SubclassTaskWrapper +from .reinit_task_wrapper import ReinitTaskWrapper +from .environment_task_env import TaskEnv # , PettingZooTaskEnv diff --git a/core/task_interface/environment_task_env.py b/core/task_interface/environment_task_env.py new file mode 100644 index 00000000..a4a9e2fd --- /dev/null +++ b/core/task_interface/environment_task_env.py @@ -0,0 +1,89 @@ +import gymnasium as gym +# import pettingzoo + + +class TaskEnv(gym.Env): + # TODO: Update to new TaskSpace API + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.task_completion = 0.0 + self.task_space = None + self.task = None + + def reset(self, *args, **kwargs): + if "new_task" in kwargs: + new_task = kwargs.pop("new_task") + self.change_task(new_task) + # TODO: Handle failure case for change task + self.task = new_task + + obs, info = super().reset(*args, **kwargs) + return self.observation(obs), info + + def change_task(self, new_task): + """ + Changes the task of the existing environment to the new_task. + + Each environment will implement tasks differently. The easiest system would be to call a + function or set an instance variable to change the task. + + Some environments may need to be reset or even reinitialized to change the task. + If you need to reset or re-init the environment here, make sure to check + that it is not in the middle of an episode to avoid unexpected behavior. + """ + raise NotImplementedError + + def add_task(self, task): + raise NotImplementedError("This environment does not support adding tasks.") + + def _task_completion(self, obs, rew, term, trunc, info) -> float: + """ + Implement this function to indicate whether the selected task has been completed. + This can be determined using the observation, rewards, term, trunc, info or internal values + from the environment. Intended to be used for automatic curricula. + Returns a boolean or float value indicating binary completion or scalar degree of completion. + """ + return 1.0 if term or trunc else 0.0 + + def _encode_goal(self): + """ + Implement this method to indicate which task is selected to the agent. + Returns: Numpy array encoding the goal. + """ + return None + + def observation(self, observation): + """ + Adds the goal encoding to the observation. + Override to add additional task-specific observations. + Returns a modified observation. + TODO: Complete this implementation and find way to support centralized encodings + """ + # Add goal to observation + goal_encoding = self._encode_goal() + if goal_encoding is not None: + observation['goal'] = goal_encoding + + return observation + + def step(self, action): + """ + Steps the environment with the given action. + Unlike the typical Gym environment, this method should also add the + {"task_completion": self.task_completion()} key to the info dictionary + to support curricula that rely on this metric. + """ + raise NotImplementedError + + +# class PettingZooTaskEnv(TaskEnv, pettingzoo.ParallelEnv): +# def __init__(self): +# super().__init__() +# self.task = None + +# @property +# def agents(self): +# return self.agents + +# def get_current_task(self): +# return self.task diff --git a/core/task_interface/reinit_task_wrapper.py b/core/task_interface/reinit_task_wrapper.py new file mode 100644 index 00000000..4e8a24aa --- /dev/null +++ b/core/task_interface/reinit_task_wrapper.py @@ -0,0 +1,119 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import time +from typing import Callable, Tuple, Union + +import gymnasium as gym + +from .task_wrapper import TaskWrapper + + +class ReinitTaskWrapper(TaskWrapper): + """ + This is a general wrapper for tasks defined as subclasses of a base environment. + + This wrapper reinitializes the environment with the provided env function at the start of each episode. + This is a simple, general solution to using Syllabus with tasks that need to be reinitialized, but it is inefficient. + It's likely that you can achieve better performance by using a more specialized wrapper. + """ + def __init__(self, env: gym.Env, env_fn: Callable, task_space: gym.Space = None): + super().__init__(env) + + self.env_fn = env_fn + self.task_envs = {} # Save instance of each task environment to avoid reinitializing + self.task_space = task_space + self.task = None + + def encode_task(self, task): + """ + Override to convert task description into an element of the MultiDiscrete task space. + This is the identity function by default. + """ + return task + + def decode_task(self, encoding): + """ + Override to convert element of the MultiDiscrete task space into format usable by the reinit env_fn. + This is the identity function by default. + """ + return encoding + + def reset(self, new_task: Union[Tuple, int, float] = None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + """ + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info + + def change_task(self, new_task: Union[Tuple, int, float]): + """ + Change task by directly editing environment class. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + + # Update current task + if new_task not in self.task_envs: + self.task_envs[new_task] = self.env_fn(self.decode_task(new_task)) + + self.env = self.task_envs[new_task] + self.task = new_task + + def step(self, action): + """ + Step through environment and update task completion. + """ + obs, rew, term, trunc, info = self.env.step(action) + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + return self.observation(obs), rew, term, trunc, info + + +if __name__ == "__main__": + from nle.env.tasks import (NetHackEat, NetHackGold, NetHackOracle, + NetHackScore, NetHackScout, NetHackStaircase, + NetHackStaircasePet) + + def run_episode(env, task: str = None, verbose=1): + env.reset(new_task=task) + task_name = type(env.unwrapped).__name__ + term = trunc = False + ep_rew = 0 + while not (term or trunc): + action = env.action_space.sample() + _, rew, term, trunc, _ = env.step(action) + ep_rew += rew + if verbose: + print(f"Episodic reward for {task_name}: {ep_rew}") + + print("Testing NethackTaskWrapper") + N_EPISODES = 100 + + # Initialize NLE + def create_env(task): + task_class = [NetHackScore, NetHackStaircase, NetHackStaircasePet, NetHackOracle, NetHackGold, NetHackEat, NetHackScout][task] + return task_class() + + nethack_env = NetHackScore() + nethack_task_env = ReinitTaskWrapper(nethack_env, create_env) + + start_time = time.time() + + for _ in range(N_EPISODES): + run_episode(nethack_task_env, verbose=0) + + end_time = time.time() + print(f"Run time same task: {end_time - start_time}") + start_time = time.time() + + for _ in range(N_EPISODES): + nethack_task = gym.spaces.Discrete(7).sample() + run_episode(nethack_task_env, task=nethack_task, verbose=0) + + end_time = time.time() + print(f"Run time swapping tasks: {end_time - start_time}") diff --git a/core/task_interface/subclass_task_wrapper.py b/core/task_interface/subclass_task_wrapper.py new file mode 100644 index 00000000..e558181c --- /dev/null +++ b/core/task_interface/subclass_task_wrapper.py @@ -0,0 +1,102 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import copy +from typing import List + +import gymnasium as gym +import numpy as np +from gymnasium import spaces +from syllabus.task_space import TaskSpace + +from .task_wrapper import TaskWrapper + + +class SubclassTaskWrapper(TaskWrapper): + # TODO: Automated tests + """ + This is a general wrapper for tasks defined as subclasses of a base environment. + + This wrapper reinitializes the environment with the provided env function at the start of each episode. + This is a simple, general solution to using Syllabus with tasks that need to be reinitialized, but it is inefficient. + It's likely that you can achieve better performance by using a more specialized wrapper. + """ + def __init__(self, env: gym.Env, task_subclasses: List[gym.Env] = None, **env_init_kwargs): + super().__init__(env) + + self.task_list = task_subclasses + self.task_space = TaskSpace(spaces.Discrete(len(self.task_list)), self.task_list) + self._env_init_kwargs = env_init_kwargs # kwargs for reinitializing the base environment + + # Add goal space to observation + self.observation_space = copy.deepcopy(self.env.observation_space) + self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list)) + + # Tracking episode end + self.done = True + + # Initialize all tasks + original_class = self.env.__class__ + for task in self.task_list: + self.env.__class__ = task + self.env.__init__(**self._env_init_kwargs) + + self.env.__class__ = original_class + self.env.__init__(**self._env_init_kwargs) + + @property + def current_task(self): + return self.env.__class__ + + def _task_name(self, task): + return self.task.__name__ + + def reset(self, new_task: int = None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + """ + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + self.done = False + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + # Ignore new task if mid episode + if self.current_task.__init__ != self._task_class(new_task).__init__ and not self.done: + raise RuntimeError("Cannot change task mid-episode.") + + # Ignore if task is unknown + if new_task >= len(self.task_list): + raise RuntimeError(f"Unknown task {new_task}.") + + # Update current task + prev_task = self.task + self.task = new_task + self.env.__class__ = self._task_class(new_task) + + # If task requires reinitialization + if type(self.env).__init__ != prev_task.__init__: + self.env.__init__(**self._env_init_kwargs) + + def _encode_goal(self): + goal_encoding = np.zeros(len(self.task_list)) + goal_encoding[self.task] = 1 + return goal_encoding + + def step(self, action): + """ + Step through environment and update task completion. + """ + obs, rew, term, trunc, info = self.env.step(action) + self.done = term or trunc + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + return self.observation(obs), rew, term, trunc, info diff --git a/core/task_interface/task_wrapper.py b/core/task_interface/task_wrapper.py new file mode 100644 index 00000000..89c94ff3 --- /dev/null +++ b/core/task_interface/task_wrapper.py @@ -0,0 +1,99 @@ +import gymnasium as gym +# import pettingzoo +# from pettingzoo.utils.wrappers.base_parallel import BaseParallelWraper + + +class TaskWrapper(gym.Wrapper): + # TODO: Update to new TaskSpace API + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.task_completion = 0.0 + self.task_space = None + self.task = None # TODO: Would making this a property protect from accidental overriding? + + def reset(self, *args, **kwargs): + if "new_task" in kwargs: + new_task = kwargs.pop("new_task") + self.change_task(new_task) + # TODO: Handle failure case for change task + self.task = new_task + return self.observation(super().reset(*args, **kwargs)) + + def change_task(self, new_task): + """ + Changes the task of the existing environment to the new_task. + + Each environment will implement tasks differently. The easiest system would be to call a + function or set an instance variable to change the task. + + Some environments may need to be reset or even reinitialized to change the task. + If you need to reset or re-init the environment here, make sure to check + that it is not in the middle of an episode to avoid unexpected behavior. + """ + raise NotImplementedError + + def add_task(self, task): + raise NotImplementedError("This environment does not support adding tasks.") + + def _task_completion(self, obs, rew, term, trunc, info) -> float: + """ + Implement this function to indicate whether the selected task has been completed. + This can be determined using the observation, rewards, term, trunc, info or internal values + from the environment. Intended to be used for automatic curricula. + Returns a boolean or float value indicating binary completion or scalar degree of completion. + """ + return 1.0 if term or trunc else 0.0 + + def _encode_goal(self): + """ + Implement this method to indicate which task is selected to the agent. + Returns: Numpy array encoding the goal. + """ + return None + + def observation(self, observation): + """ + Adds the goal encoding to the observation. + Override to add additional task-specific observations. + Returns a modified observation. + TODO: Complete this implementation and find way to support centralized encodings + """ + # Add goal to observation + goal_encoding = self._encode_goal() + if goal_encoding is not None: + observation['goal'] = goal_encoding + + return observation + + def step(self, action): + obs, rew, term, trunc, info = self.env.step(action) + + # Determine completion status of the current task + self.task_completion = self._task_completion(obs, rew, term, trunc, info) + info["task_completion"] = self.task_completion + + return self.observation(obs), rew, term, trunc, info + + def __getattr__(self, attr): + env_attr = self.env.__class__.__dict__.get(attr, None) + + if env_attr and callable(env_attr): + return env_attr + + +# class PettingZooTaskWrapper(TaskWrapper, BaseParallelWraper): +# def __init__(self, env: pettingzoo.ParallelEnv): +# super().__init__(env) +# self.task = None + +# @property +# def agents(self): +# return self.env.agents + +# def __getattr__(self, attr): +# env_attr = getattr(self.env, attr, None) +# if env_attr: +# return env_attr + +# def get_current_task(self): +# return self.current_task \ No newline at end of file diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 00000000..538b9df0 --- /dev/null +++ b/core/utils.py @@ -0,0 +1,32 @@ +from itertools import product +from typing import Union + +import numpy as np + + +def decorate_all_functions(function_decorator): + def decorator(cls): + for base_cls in cls.__bases__: + for name, obj in vars(base_cls).items(): + parent_func = getattr(base_cls, name) + child_func = getattr(cls, name) + + # Only apply decorator to functions not overridden by subclass. + if callable(obj) and child_func == parent_func: + setattr(cls, name, function_decorator(obj)) + return cls + return decorator + + +class UsageError(Exception): + ... + pass + + +def enumerate_axes(list_or_size: Union[np.ndarray, int]): + if isinstance(list_or_size, int) or isinstance(list_or_size, np.int64): + return tuple(range(list_or_size)) + elif isinstance(list_or_size, list) or isinstance(list_or_size, np.ndarray): + return tuple(product(*[enumerate_axes(x) for x in list_or_size])) + else: + raise NotImplementedError(f"{type(list_or_size)}") diff --git a/curricula/__init__.py b/curricula/__init__.py new file mode 100644 index 00000000..650a40e7 --- /dev/null +++ b/curricula/__init__.py @@ -0,0 +1,11 @@ +import sys + +from .domain_randomization import DomainRandomization +from .learning_progress import LearningProgressCurriculum +from .noop import NoopCurriculum +from .plr.central_plr_wrapper import CentralizedPrioritizedLevelReplay +from .plr.plr_wrapper import PrioritizedLevelReplay +from .plr.task_sampler import TaskSampler +from .sequential import SequentialCurriculum +from .simple_box import SimpleBoxCurriculum +from .annealing_box import AnnealingBoxCurriculum diff --git a/curricula/annealing_box.py b/curricula/annealing_box.py new file mode 100644 index 00000000..101981c7 --- /dev/null +++ b/curricula/annealing_box.py @@ -0,0 +1,56 @@ +import typing +from typing import Any, List, Union, Sequence, SupportsFloat, SupportsInt, Tuple +import numpy as np + +from gymnasium.spaces import Box +from syllabus.core import Curriculum + +class AnnealingBoxCurriculum(Curriculum): + REQUIRES_STEP_UPDATES = True + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + + def __init__( + self, + *curriculum_args, + start_values: List[SupportsFloat], + end_values: List[SupportsFloat], + total_steps: Tuple[int, List[int]], + **curriculum_kwargs, + ): + super().__init__(*curriculum_args, **curriculum_kwargs) + assert isinstance( + self.task_space.gym_space, Box + ), "AnnealingBoxCurriculum only supports Box task spaces." + + self.start_values = np.array(start_values, dtype=np.float32) + self.end_values = np.array(end_values, dtype=np.float32) + + # Convert total_steps to list if necessary + if isinstance(total_steps, SupportsInt): + total_steps = [total_steps] + self.total_steps = np.array(total_steps, dtype=np.int32) + + assert len(self.start_values) == len(self.end_values), "Length of start_values and end_values must be the same." + assert all(x > 0 for x in self.total_steps), "All elements of total_steps must be greater than 0." + + self.current_step = 0 + + def update_on_step(self, *args, **kwargs) -> None: + """ + Update the curriculum based on the current training timestep. + """ + self.current_step += 1 + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + # Linear annealing from start_values to end_values + annealed_values = ( + self.start_values + (self.end_values - self.start_values) * + np.minimum(self.current_step, self.total_steps) / self.total_steps + ) + + return [annealed_values.copy() for _ in range(k)] diff --git a/curricula/domain_randomization.py b/curricula/domain_randomization.py new file mode 100644 index 00000000..e4183904 --- /dev/null +++ b/curricula/domain_randomization.py @@ -0,0 +1,21 @@ +from typing import Any, List + +from syllabus.core import Curriculum + + +class DomainRandomization(Curriculum): + """A simple but strong baseline for curriculum learning that uniformly samples a task from the task space. + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def _sample_distribution(self) -> List[float]: + """ + Returns a sample distribution over the task space. + """ + # Uniform distribution + return [1.0 / self.num_tasks for _ in range(self.num_tasks)] + + def add_task(self, task: Any) -> None: + self.task_space.add_task(task) diff --git a/curricula/learning_progress.py b/curricula/learning_progress.py new file mode 100644 index 00000000..e9e92f58 --- /dev/null +++ b/curricula/learning_progress.py @@ -0,0 +1,198 @@ +import math +import random +import warnings +from collections import defaultdict +from typing import List + +import numpy as np +from gymnasium.spaces import Discrete, MultiDiscrete +from scipy.stats import norm + +from syllabus.core import Curriculum +from syllabus.task_space import TaskSpace + + +class LearningProgressCurriculum(Curriculum): + """ + Provides an interface for tracking success rates of discrete tasks and sampling tasks + based on their success rate using the method from https://arxiv.org/abs/2106.14876. + TODO: Support task spaces aside from Discrete + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, *args, ema_alpha=0.1, **kwargs): + super().__init__(*args, **kwargs) + self.ema_alpha = ema_alpha + + assert isinstance(self.task_space.gym_space, (Discrete, MultiDiscrete)) + self._p_fast = np.zeros(self.num_tasks) + self._p_slow = np.zeros(self.num_tasks) + + def update_task_progress(self, task: int, progress: float, env_id: int = None): + """ + Update the success rate for the given task using a fast and slow exponential moving average. + """ + if task is None or progress == 0.0: + return + super().update_task_progress(task, progress) + + self._p_fast[task] = (progress * self.ema_alpha) + (self._p_fast[task] * (1.0 - self.ema_alpha)) + self._p_slow[task] = (self._p_fast[task] * self.ema_alpha) + (self._p_slow[task] * (1.0 - self.ema_alpha)) + + def _learning_progress(self, task: int, reweight: bool = True) -> float: + """ + Compute the learning progress metric for the given task. + """ + slow = self._reweight(self._p_slow[task]) if reweight else self._p_slow[task] + fast = self._reweight(self._p_fast[task]) if reweight else self._p_fast[task] + return abs(fast - slow) + + def _reweight(self, p: np.ndarray, p_theta: float = 0.1) -> float: + """ + Reweight the given success rate using the reweighting function from the paper. + """ + numerator = p * (1.0 - p_theta) + denominator = p + p_theta * (1.0 - 2.0 * p) + return numerator / denominator + + def _sigmoid(self, x: np.ndarray): + return 1 / (1 + np.exp(-x)) + + def _sample_distribution(self) -> List[float]: + if self.num_tasks == 0: + return [] + + task_dist = np.ones(self.num_tasks) / self.num_tasks + + task_lps = self._learning_progress(np.asarray(self.tasks)) + posidxs = [i for i, lp in enumerate(task_lps) if lp > 0] + zeroout = len(posidxs) > 0 + + subprobs = task_lps[posidxs] if zeroout else task_lps + std = np.std(subprobs) + subprobs = (subprobs - np.mean(subprobs)) / (std if std else 1) # z-score + subprobs = self._sigmoid(subprobs) # sigmoid + subprobs = subprobs / np.sum(subprobs) # normalize + if zeroout: + # If some tasks have nonzero progress, zero out the rest + task_dist = np.zeros(len(task_lps)) + task_dist[posidxs] = subprobs + else: + # If all tasks have 0 progress, return uniform distribution + task_dist = subprobs + + return task_dist + + def on_step(self, obs, rew, term, trunc, info) -> None: + """ + Update the curriculum with the current step results from the environment. + """ + pass + + +if __name__ == "__main__": + def sample_binomial(p=0.5, n=200): + success = 0.0 + for _ in range(n): + rand = random.random() + if rand < p: + success += 1.0 + return success / n + + def generate_history(center=0, curve=1.0, n=100): + center = center if center else n / 2.0 + + def sig(x, x_0=center, curve=curve): + return 1.0 / (1.0 + math.e**(curve * (x_0 - x))) + history = [] + probs = [] + success_prob = 0.0 + for i in range(n): + probs.append(success_prob) + history.append(sample_binomial(p=success_prob)) + success_prob = sig(i) + return history, probs + + tasks = range(20) + histories = {task: generate_history(center=random.randint(0, 100), curve=random.random()) for task in tasks} + + curriculum = LearningProgressCurriculum(TaskSpace(len(tasks))) + for i in range(len(histories[0][0])): + for task in tasks: + curriculum.update_task_progress(task, histories[task][0][i]) + if i > 10: + distribution = curriculum._sample_distribution() + print("[", end="") + for j, prob in enumerate(distribution): + print(f"{prob:.3f}", end="") + if j < len(distribution) - 1: + print(", ", end="") + print("]") + + tasks = [0] + histories = {task: generate_history(n=200, center=75, curve=0.1) for task in tasks} + curriculum = LearningProgressCurriculum(TaskSpace(len(tasks))) + lp_raw = [] + lp_reweight = [] + p_fast = [] + p_slow = [] + true_probs = [] + estimates = [] + for estimate, true_prob in zip(histories[0][0], histories[0][1]): + curriculum.update_task_progress(tasks[0], estimate) + lp_raw.append(curriculum._learning_progress(tasks[0], reweight=False)) + lp_reweight.append(curriculum._learning_progress(tasks[0])) + p_fast.append(curriculum._p_fast[0]) + p_slow.append(curriculum._p_slow[0]) + true_probs.append(true_prob) + estimates.append(estimate) + + try: + import matplotlib.pyplot as plt + + # TODO: Plot probabilities + def plot_history(true_probs, estimates, p_slow, p_fast, lp_reweight, lp_raw): + x_axis = range(0, len(true_probs)) + plt.plot(x_axis, true_probs, color="#222222", label="True Success Probability") + plt.plot(x_axis, estimates, color="#888888", label="Estimated Success Probability") + plt.plot(x_axis, p_slow, color="#ee3333", label="p_slow") + plt.plot(x_axis, p_fast, color="#33ee33", label="p_fast") + plt.plot(x_axis, lp_raw, color="#c4c25b", label="Learning Progress") + plt.plot(x_axis, lp_reweight, color="#1544ee", label="Learning Progress Reweighted") + plt.xlabel('Time step') + plt.ylabel('Learning Progress') + plt.legend() + plt.show() + + plot_history(true_probs, estimates, p_slow, p_fast, lp_reweight, lp_raw) + + # Reweight Plot + x_axis = np.linspace(0, 1, num=100) + y_axis = [] + for x in x_axis: + y_axis.append(curriculum._reweight(x)) + plt.plot(x_axis, y_axis, color="blue", label="p_theta = 0.1") + plt.xlabel('p') + plt.ylabel('reweight') + plt.legend() + plt.show() + + # Z-score plot + tasks = [i for i in range(50)] + curriculum = LearningProgressCurriculum(TaskSpace(len(tasks))) + histories = {task: generate_history(n=200, center=60, curve=0.09) for task in tasks} + for i in range(len(histories[0][0])): + for task in tasks: + curriculum.update_task_progress(task, histories[task][0][i]) + distribution = curriculum._sample_distribution() + x_axis = np.linspace(-3, 3, num=len(distribution)) + sigmoid_axis = curriculum._sigmoid(x_axis) + plt.plot(x_axis, norm.pdf(x_axis, 0, 1), color="blue", label="Normal distribution") + plt.plot(x_axis, sigmoid_axis, color="orange", label="Sampling weight") + plt.xlabel('Z-scored distributed learning progress') + plt.legend() + plt.show() + except ImportError: + warnings.warn("Matplotlib not installed. Plotting will not work.") diff --git a/curricula/noop.py b/curricula/noop.py new file mode 100644 index 00000000..fb5d8ae4 --- /dev/null +++ b/curricula/noop.py @@ -0,0 +1,62 @@ +from typing import Any, List, Union + +from syllabus.core import Curriculum + + +class NoopCurriculum(Curriculum): + """ + Used to to test API without a curriculum. + """ + REQUIRES_STEP_UPDATES = True + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, default_task, *curriculum_args, **curriculum_kwargs): + super().__init__(*curriculum_args, **curriculum_kwargs) + self.default_task = self.task_space.encode(default_task) + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + return [self.default_task for _ in range(k)] + + def update_task_progress(self, task, success_prob, env_id: int = None) -> None: + """ + Update the curriculum with a task and its success probability upon + success or failure. + """ + pass + + def update_on_step(self, task, obs, rew, term, trunc, info, env_id: int = None) -> None: + """ + Update the curriculum with the current step results from the environment. + """ + pass + + def update_on_step_batch(self, step_results, env_id: int = None) -> None: + """ + Update the curriculum with a batch of step results from the environment. + """ + pass + + def update_on_episode(self, episode_return, episode_length, episode_task, env_id: int = None) -> None: + """ + Update the curriculum with episode results from the environment. + """ + pass + + def update_on_demand(self, metrics): + """ + Update the curriculum with arbitrary inputs. + """ + pass + + def add_task(self, task: tuple) -> None: + pass + + def update(self, update_data): + """ + Update the curriculum with the specified update type. + """ + pass diff --git a/curricula/plr/__init__.py b/curricula/plr/__init__.py new file mode 100644 index 00000000..75e6bc2d --- /dev/null +++ b/curricula/plr/__init__.py @@ -0,0 +1,3 @@ +from .central_plr_wrapper import CentralizedPrioritizedLevelReplay +from .plr_wrapper import PrioritizedLevelReplay +from .task_sampler import TaskSampler diff --git a/curricula/plr/central_plr_wrapper.py b/curricula/plr/central_plr_wrapper.py new file mode 100644 index 00000000..7f69ea85 --- /dev/null +++ b/curricula/plr/central_plr_wrapper.py @@ -0,0 +1,239 @@ +import warnings +from typing import Any, Dict, List, Tuple, Union + +import gymnasium as gym +import torch +from gymnasium.spaces import Discrete, MultiDiscrete + +from syllabus.core import Curriculum, enumerate_axes +from syllabus.task_space import TaskSpace + +from .task_sampler import TaskSampler + + +class RolloutStorage(object): + def __init__( + self, + num_steps: int, + num_processes: int, + requires_value_buffers: bool, + action_space: gym.Space = None, + ): + self._requires_value_buffers = requires_value_buffers + self.tasks = torch.zeros(num_steps, num_processes, 1, dtype=torch.int) + self.masks = torch.ones(num_steps + 1, num_processes, 1) + + if requires_value_buffers: + self.returns = torch.zeros(num_steps + 1, num_processes, 1) + self.rewards = torch.zeros(num_steps, num_processes, 1) + self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) + else: + if action_space is None: + raise ValueError( + "Action space must be provided to PLR for strategies 'policy_entropy', 'least_confidence', 'min_margin'" + ) + self.action_log_dist = torch.zeros(num_steps, num_processes, action_space.n) + + self.num_steps = num_steps + self.step = 0 + + def to(self, device): + self.masks = self.masks.to(device) + self.tasks = self.tasks.to(device) + if self._requires_value_buffers: + self.rewards = self.rewards.to(device) + self.value_preds = self.value_preds.to(device) + self.returns = self.returns.to(device) + else: + self.action_log_dist = self.action_log_dist.to(device) + + def insert(self, masks, action_log_dist=None, value_preds=None, rewards=None, tasks=None): + if self._requires_value_buffers: + assert (value_preds is not None and rewards is not None), "Selected strategy requires value_preds and rewards" + if len(rewards.shape) == 3: + rewards = rewards.squeeze(2) + self.value_preds[self.step].copy_(torch.as_tensor(value_preds)) + self.rewards[self.step].copy_(torch.as_tensor(rewards)[:, None]) + self.masks[self.step + 1].copy_(torch.as_tensor(masks)[:, None]) + else: + self.action_log_dist[self.step].copy_(action_log_dist) + if tasks is not None: + assert isinstance(tasks[0], int), "Provided task must be an integer" + self.tasks[self.step].copy_(torch.as_tensor(tasks)[:, None]) + self.step = (self.step + 1) % self.num_steps + + def after_update(self): + self.masks[0].copy_(self.masks[-1]) + + def compute_returns(self, next_value, gamma, gae_lambda): + assert self._requires_value_buffers, "Selected strategy does not use compute_rewards." + self.value_preds[-1] = next_value + gae = 0 + for step in reversed(range(self.rewards.size(0))): + delta = ( + self.rewards[step] + + gamma * self.value_preds[step + 1] * self.masks[step + 1] + - self.value_preds[step] + ) + gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae + self.returns[step] = gae + self.value_preds[step] + + +class CentralizedPrioritizedLevelReplay(Curriculum): + """ Prioritized Level Replay (PLR) Curriculum. + + Args: + task_space (TaskSpace): The task space to use for the curriculum. + *curriculum_args: Positional arguments to pass to the curriculum. + task_sampler_kwargs_dict (dict): Keyword arguments to pass to the task sampler. See TaskSampler for details. + action_space (gym.Space): The action space to use for the curriculum. Required for some strategies. + device (str): The device to use to store curriculum data, either "cpu" or "cuda". + num_steps (int): The number of steps to store in the rollouts. + num_processes (int): The number of parallel environments. + gamma (float): The discount factor used to compute returns + gae_lambda (float): The GAE lambda value. + suppress_usage_warnings (bool): Whether to suppress warnings about improper usage. + **curriculum_kwargs: Keyword arguments to pass to the curriculum. + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = True + + def __init__( + self, + task_space: TaskSpace, + *curriculum_args, + task_sampler_kwargs_dict: dict = None, + action_space: gym.Space = None, + device: str = "cpu", + num_steps: int = 256, + num_processes: int = 64, + gamma: float = 0.999, + gae_lambda: float = 0.95, + suppress_usage_warnings=False, + **curriculum_kwargs, + ): + # Preprocess curriculum intialization args + if task_sampler_kwargs_dict is None: + task_sampler_kwargs_dict = {} + + self._strategy = task_sampler_kwargs_dict.get("strategy", None) + if not isinstance(task_space.gym_space, Discrete) and not isinstance(task_space.gym_space, MultiDiscrete): + raise ValueError( + f"Task space must be discrete or multi-discrete, got {task_space.gym_space}." + ) + if "num_actors" in task_sampler_kwargs_dict and task_sampler_kwargs_dict['num_actors'] != num_processes: + warnings.warn(f"Overwriting 'num_actors' {task_sampler_kwargs_dict['num_actors']} in task sampler kwargs with PLR num_processes {num_processes}.") + task_sampler_kwargs_dict["num_actors"] = num_processes + super().__init__(task_space, *curriculum_args, **curriculum_kwargs) + + self._num_steps = num_steps # Number of steps stored in rollouts and used to update task sampler + self._num_processes = num_processes # Number of parallel environments + self._gamma = gamma + self._gae_lambda = gae_lambda + self._supress_usage_warnings = suppress_usage_warnings + self._task2index = {task: i for i, task in enumerate(self.tasks)} + self._task_sampler = TaskSampler(self.tasks, action_space=action_space, **task_sampler_kwargs_dict) + self._rollouts = RolloutStorage( + self._num_steps, + self._num_processes, + self._task_sampler.requires_value_buffers, + action_space=action_space, + ) + self._rollouts.to(device) + # TODO: Fix this feature + self.num_updates = 0 # Used to ensure proper usage + self.num_samples = 0 # Used to ensure proper usage + + def _validate_metrics(self, metrics: Dict): + try: + masks = torch.Tensor(1 - metrics["dones"]) + tasks = metrics["tasks"] + tasks = [self._task2index[t] for t in tasks] + except KeyError as e: + raise KeyError( + "Missing or malformed PLR update. Must include 'masks', and 'tasks', and all tasks must be in the task space" + ) from e + + # Parse optional update values (required for some strategies) + value = next_value = rew = action_log_dist = None + if self._task_sampler.requires_value_buffers: + if "value" not in metrics or "rew" not in metrics: + raise KeyError( + f"'value' and 'rew' must be provided in every update for the strategy {self._strategy}." + ) + value = metrics["value"] + rew = metrics["rew"] + else: + try: + action_log_dist = metrics["action_log_dist"] + except KeyError as e: + raise KeyError( + f"'action_log_dist' must be provided in every update for the strategy {self._strategy}." + ) from e + + if self._task_sampler.requires_value_buffers: + try: + next_value = metrics["next_value"] + except KeyError as e: + raise KeyError( + f"'next_value' must be provided in the update every {self.num_steps} steps for the strategy {self._strategy}." + ) from e + + return masks, tasks, value, rew, action_log_dist, next_value + + def update_on_demand(self, metrics: Dict): + """ + Update the curriculum with arbitrary inputs. + """ + self.num_updates += 1 + masks, tasks, value, rew, action_log_dist, next_value = self._validate_metrics(metrics) + + # Update rollouts + self._rollouts.insert( + masks, + action_log_dist=action_log_dist, + value_preds=value, + rewards=rew, + tasks=tasks, + ) + + # Update task sampler + if self._rollouts.step == 0: + if self._task_sampler.requires_value_buffers: + self._rollouts.compute_returns(next_value, self._gamma, self._gae_lambda) + self._task_sampler.update_with_rollouts(self._rollouts) + self._rollouts.after_update() + self._task_sampler.after_update() + + def _sample_distribution(self) -> List[float]: + """ + Returns a sample distribution over the task space. + """ + return self._task_sampler.sample_weights() + + def sample(self, k: int = 1) -> Union[List, Any]: + self.num_samples += 1 + if self._should_use_startup_sampling(): + return self._startup_sample() + else: + return [self._task_sampler.sample() for _ in range(k)] + + def _enumerate_tasks(self, space): + assert isinstance(space, Discrete) or isinstance(space, MultiDiscrete), f"Unsupported task space {space}: Expected Discrete or MultiDiscrete" + if isinstance(space, Discrete): + return list(range(space.n)) + else: + return list(enumerate_axes(space.nvec)) + + def log_metrics(self, writer, step=None): + """ + Log the task distribution to the provided tensorboard writer. + """ + super().log_metrics(writer, step) + metrics = self._task_sampler.metrics() + writer.add_scalar("curriculum/proportion_seen", metrics["proportion_seen"], step) + writer.add_scalar("curriculum/score", metrics["score"], step) + for task in list(self.task_space.tasks)[:10]: + writer.add_scalar(f"curriculum/task_{task - 1}_score", metrics["task_scores"][task - 1], step) + writer.add_scalar(f"curriculum/task_{task - 1}_staleness", metrics["task_staleness"][task - 1], step) diff --git a/curricula/plr/plr_wrapper.py b/curricula/plr/plr_wrapper.py new file mode 100644 index 00000000..9c808ddc --- /dev/null +++ b/curricula/plr/plr_wrapper.py @@ -0,0 +1,295 @@ +import warnings +from typing import Any, Dict, List, Tuple, Union + +import gymnasium as gym +import numpy as np +import torch +from gymnasium.spaces import Discrete, MultiDiscrete + +from syllabus.core import Curriculum, UsageError, enumerate_axes +from syllabus.task_space import TaskSpace + +from .task_sampler import TaskSampler + + +class RolloutStorage(object): + def __init__( + self, + num_steps: int, + num_processes: int, + requires_value_buffers: bool, + observation_space: gym.Space, + action_space: gym.Space = None, + get_value=None, + ): + self.num_steps = num_steps + self.buffer_steps = num_steps * 4 # Hack to prevent overflow from lagging updates. + self.num_processes = num_processes + self._requires_value_buffers = requires_value_buffers + self._get_value = get_value + self.tasks = torch.zeros(self.buffer_steps, num_processes, 1, dtype=torch.int) + self.masks = torch.ones(self.buffer_steps + 1, num_processes, 1) + self.obs = [[[0] for _ in range(self.num_processes)]] * self.buffer_steps + self.env_steps = [0] * num_processes + self.ready_buffers = set() + + if requires_value_buffers: + self.returns = torch.zeros(self.buffer_steps + 1, num_processes, 1) + self.rewards = torch.zeros(self.buffer_steps, num_processes, 1) + self.value_preds = torch.zeros(self.buffer_steps + 1, num_processes, 1) + else: + if action_space is None: + raise ValueError( + "Action space must be provided to PLR for strategies 'policy_entropy', 'least_confidence', 'min_margin'" + ) + self.action_log_dist = torch.zeros(self.buffer_steps, num_processes, action_space.n) + + self.num_steps = num_steps + + def to(self, device): + self.masks = self.masks.to(device) + self.tasks = self.tasks.to(device) + if self._requires_value_buffers: + self.rewards = self.rewards.to(device) + self.value_preds = self.value_preds.to(device) + self.returns = self.returns.to(device) + else: + self.action_log_dist = self.action_log_dist.to(device) + + def insert_at_index(self, env_index, mask=None, action_log_dist=None, obs=None, reward=None, task=None, steps=1): + step = self.env_steps[env_index] + end_step = step + steps + + if mask is not None: + self.masks[step + 1:end_step + 1, env_index].copy_(torch.as_tensor(mask[:, None])) + + if obs is not None: + for s in range(step, end_step): + self.obs[s][env_index] = obs[s - step] + + if reward is not None: + self.rewards[step:end_step, env_index].copy_(torch.as_tensor(reward[:, None])) + + if action_log_dist is not None: + self.action_log_dist[step:end_step, env_index].copy_(torch.as_tensor(action_log_dist[:, None])) + + if task is not None: + try: + int(task[0]) + except TypeError: + assert isinstance(task, int), f"Provided task must be an integer, got {task[0]} with type {type(task[0])} instead." + self.tasks[step:end_step, env_index].copy_(torch.as_tensor(np.array(task)[:, None])) + + self.env_steps[env_index] += steps + if env_index not in self.ready_buffers and self.env_steps[env_index] >= self.num_steps: + self.ready_buffers.add(env_index) + + def _get_values(self, env_index): + if self._get_value is None: + raise UsageError("Selected strategy requires value predictions. Please provide get_value function.") + for step in range(0, self.num_steps, self.num_processes): + obs = self.obs[step: step + self.num_processes][env_index] + values = self._get_value(obs) + + # Reshape values if necessary + if len(values.shape) == 3: + warnings.warn(f"Value function returned a 3D tensor of shape {values.shape}. Attempting to squeeze last dimension.") + values = torch.squeeze(values, -1) + if len(values.shape) == 1: + warnings.warn(f"Value function returned a 1D tensor of shape {values.shape}. Attempting to unsqueeze last dimension.") + values = torch.unsqueeze(values, -1) + + self.value_preds[step: step + self.num_processes, env_index].copy_(values) + + def after_update(self, env_index): + # After consuming the first num_steps of data, remove them and shift the remaining data in the buffer + self.tasks = self.tasks.roll(-self.num_steps, 0) + self.masks = self.masks.roll(-self.num_steps, 0) + self.obs[0:][env_index] = self.obs[self.num_steps: self.buffer_steps][env_index] + + if self._requires_value_buffers: + self.returns = self.returns.roll(-self.num_steps, 0) + self.rewards = self.rewards.roll(-self.num_steps, 0) + self.value_preds = self.value_preds.roll(-self.num_steps, 0) + else: + self.action_log_dist = self.action_log_dist.roll(-self.num_steps, 0) + + self.env_steps[env_index] -= self.num_steps + self.ready_buffers.remove(env_index) + + def compute_returns(self, gamma, gae_lambda, env_index): + assert self._requires_value_buffers, "Selected strategy does not use compute_rewards." + self._get_values(env_index) + gae = 0 + for step in reversed(range(self.rewards.size(0), self.num_steps)): + delta = ( + self.rewards[step, env_index] + + gamma * self.value_preds[step + 1, env_index] * self.masks[step + 1, env_index] + - self.value_preds[step, env_index] + ) + gae = delta + gamma * gae_lambda * self.masks[step + 1, env_index] * gae + self.returns[step, env_index] = gae + self.value_preds[step, env_index] + + +def null(x): + return None + + +class PrioritizedLevelReplay(Curriculum): + """ Prioritized Level Replay (PLR) Curriculum. + + Args: + task_space (TaskSpace): The task space to use for the curriculum. + *curriculum_args: Positional arguments to pass to the curriculum. + task_sampler_kwargs_dict (dict): Keyword arguments to pass to the task sampler. See TaskSampler for details. + action_space (gym.Space): The action space to use for the curriculum. Required for some strategies. + device (str): The device to use to store curriculum data, either "cpu" or "cuda". + num_steps (int): The number of steps to store in the rollouts. + num_processes (int): The number of parallel environments. + gamma (float): The discount factor used to compute returns + gae_lambda (float): The GAE lambda value. + suppress_usage_warnings (bool): Whether to suppress warnings about improper usage. + **curriculum_kwargs: Keyword arguments to pass to the curriculum. + """ + REQUIRES_STEP_UPDATES = True + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__( + self, + task_space: TaskSpace, + observation_space: gym.Space, + *curriculum_args, + task_sampler_kwargs_dict: dict = None, + action_space: gym.Space = None, + device: str = "cpu", + num_steps: int = 256, + num_processes: int = 64, + gamma: float = 0.999, + gae_lambda: float = 0.95, + suppress_usage_warnings=False, + get_value=null, + get_action_log_dist=null, + **curriculum_kwargs, + ): + # Preprocess curriculum intialization args + if task_sampler_kwargs_dict is None: + task_sampler_kwargs_dict = {} + + self._strategy = task_sampler_kwargs_dict.get("strategy", None) + if not isinstance(task_space.gym_space, Discrete) and not isinstance(task_space.gym_space, MultiDiscrete): + raise ValueError( + f"Task space must be discrete or multi-discrete, got {task_space.gym_space}." + ) + if "num_actors" in task_sampler_kwargs_dict and task_sampler_kwargs_dict['num_actors'] != num_processes: + warnings.warn(f"Overwriting 'num_actors' {task_sampler_kwargs_dict['num_actors']} in task sampler kwargs with PLR num_processes {num_processes}.") + task_sampler_kwargs_dict["num_actors"] = num_processes + super().__init__(task_space, *curriculum_args, **curriculum_kwargs) + + self._num_steps = num_steps # Number of steps stored in rollouts and used to update task sampler + self._num_processes = num_processes # Number of parallel environments + self._gamma = gamma + self._gae_lambda = gae_lambda + self._supress_usage_warnings = suppress_usage_warnings + self._get_action_log_dist = get_action_log_dist + self._task2index = {task: i for i, task in enumerate(self.tasks)} + + self._task_sampler = TaskSampler(self.tasks, action_space=action_space, **task_sampler_kwargs_dict) + self._rollouts = RolloutStorage( + self._num_steps, + self._num_processes, + self._task_sampler.requires_value_buffers, + observation_space, + action_space=action_space, + get_value=get_value if get_value is not None else null, + ) + self._rollouts.to(device) + + def set_value_fn(self, value_fn): + self._rollouts._get_value = value_fn + + def _sample_distribution(self) -> List[float]: + """ + Returns a sample distribution over the task space. + """ + return self._task_sampler.sample_weights() + + def sample(self, k: int = 1) -> Union[List, Any]: + if self._should_use_startup_sampling(): + return self._startup_sample() + else: + return [self._task_sampler.sample() for _ in range(k)] + + def update_on_step(self, task, obs, rew, term, trunc, info, env_id: int = None) -> None: + """ + Update the curriculum with the current step results from the environment. + """ + assert env_id is not None, "env_id must be provided for PLR updates." + if env_id >= self._num_processes: + warnings.warn(f"Env index {env_id} is greater than the number of processes {self._num_processes}. Using index {env_id % self._num_processes} instead.") + env_id = env_id % self._num_processes + + # Update rollouts + self._rollouts.insert_at_index( + env_id, + mask=np.array([not (term or trunc)]), + action_log_dist=self._get_action_log_dist(obs), + reward=np.array([rew]), + obs=np.array([obs]), + ) + + # Update task sampler + if env_id in self._rollouts.ready_buffers: + self._update_sampler(env_id) + + def update_on_step_batch( + self, step_results: List[Tuple[int, Any, int, bool, bool, Dict]], env_id: int = None + ) -> None: + """ + Update the curriculum with a batch of step results from the environment. + """ + assert env_id is not None, "env_id must be provided for PLR updates." + if env_id >= self._num_processes: + warnings.warn(f"Env index {env_id} is greater than the number of processes {self._num_processes}. Using index {env_id % self._num_processes} instead.") + env_id = env_id % self._num_processes + + tasks, obs, rews, terms, truncs, infos = step_results + self._rollouts.insert_at_index( + env_id, + mask=np.logical_not(np.logical_or(terms, truncs)), + action_log_dist=self._get_action_log_dist(obs), + reward=rews, + obs=obs, + steps=len(rews), + task=tasks, + ) + + # Update task sampler + if env_id in self._rollouts.ready_buffers: + self._update_sampler(env_id) + + def _update_sampler(self, env_id): + if self._task_sampler.requires_value_buffers: + self._rollouts.compute_returns(self._gamma, self._gae_lambda, env_id) + self._task_sampler.update_with_rollouts(self._rollouts, env_id) + self._rollouts.after_update(env_id) + self._task_sampler.after_update() + + def _enumerate_tasks(self, space): + assert isinstance(space, Discrete) or isinstance(space, MultiDiscrete), f"Unsupported task space {space}: Expected Discrete or MultiDiscrete" + if isinstance(space, Discrete): + return list(range(space.n)) + else: + return list(enumerate_axes(space.nvec)) + + def log_metrics(self, writer, step=None): + """ + Log the task distribution to the provided tensorboard writer. + """ + # super().log_metrics(writer, step) + metrics = self._task_sampler.metrics() + writer.add_scalar("curriculum/proportion_seen", metrics["proportion_seen"], step) + writer.add_scalar("curriculum/score", metrics["score"], step) + # for task in list(self.task_space.tasks)[:10]: + # writer.add_scalar(f"curriculum/task_{task - 1}_score", metrics["task_scores"][task - 1], step) + # writer.add_scalar(f"curriculum/task_{task - 1}_staleness", metrics["task_staleness"][task - 1], step) diff --git a/curricula/plr/task_sampler.py b/curricula/plr/task_sampler.py new file mode 100644 index 00000000..c1e97a18 --- /dev/null +++ b/curricula/plr/task_sampler.py @@ -0,0 +1,354 @@ +# Code heavily based on the original Prioritized Level Replay implementation from https://github.com/facebookresearch/level-replay +# If you use this code, please cite the above codebase and original PLR paper: https://arxiv.org/abs/2010.03934 +import gymnasium as gym +import numpy as np +import torch + + +class TaskSampler: + """ Task sampler for Prioritized Level Replay (PLR) + + Args: + tasks (list): List of tasks to sample from + action_space (gym.spaces.Space): Action space of the environment + num_actors (int): Number of actors/processes + strategy (str): Strategy for sampling tasks. One of "value_l1", "gae", "policy_entropy", "least_confidence", "min_margin", "one_step_td_error". + replay_schedule (str): Schedule for sampling replay levels. One of "fixed" or "proportionate". + score_transform (str): Transform to apply to task scores. One of "constant", "max", "eps_greedy", "rank", "power", "softmax". + temperature (float): Temperature for score transform. Increasing temperature makes the sampling distribution more uniform. + eps (float): Epsilon for eps-greedy score transform. + rho (float): Proportion of seen tasks before replay sampling is allowed. + nu (float): Probability of sampling a replay level if using a fixed replay_schedule. + alpha (float): Linear interpolation weight for score updates. 0.0 means only use old scores, 1.0 means only use new scores. + staleness_coef (float): Linear interpolation weight for task staleness vs. task score. 0.0 means only use task score, 1.0 means only use staleness. + staleness_transform (str): Transform to apply to task staleness. One of "constant", "max", "eps_greedy", "rank", "power", "softmax". + staleness_temperature (float): Temperature for staleness transform. Increasing temperature makes the sampling distribution more uniform. + """ + def __init__( + self, + tasks: list, + action_space: gym.spaces.Space = None, + num_actors: int = 1, + strategy: str = "value_l1", + replay_schedule: str = "proportionate", + score_transform: str = "rank", + temperature: float = 0.1, + eps: float = 0.05, + rho: float = 1.0, + nu: float = 0.5, + alpha: float = 1.0, + staleness_coef: float = 0.1, + staleness_transform: str = "power", + staleness_temperature: float = 1.0, + ): + self.action_space = action_space + self.tasks = tasks + self.num_tasks = len(self.tasks) + + self.strategy = strategy + self.replay_schedule = replay_schedule + self.score_transform = score_transform + self.temperature = temperature + self.eps = eps + self.rho = rho + self.nu = nu + self.alpha = float(alpha) + self.staleness_coef = staleness_coef + self.staleness_transform = staleness_transform + self.staleness_temperature = staleness_temperature + + self.unseen_task_weights = np.array([1.0] * self.num_tasks) + self.task_scores = np.array([0.0] * self.num_tasks, dtype=float) + self.partial_task_scores = np.zeros((num_actors, self.num_tasks), dtype=float) + self.partial_task_steps = np.zeros((num_actors, self.num_tasks), dtype=np.int64) + self.task_staleness = np.array([0.0] * self.num_tasks, dtype=float) + + self.next_task_index = 0 # Only used for sequential strategy + + # Logging metrics + self._last_score = 0.0 + + if not self.requires_value_buffers and self.action_space is None: + raise ValueError( + 'Must provide action space to PLR if using "policy_entropy", "least_confidence", or "min_margin" strategies' + ) + + def update_with_rollouts(self, rollouts, actor_id=None): + if self.strategy == "random": + return + + # Update with a RolloutStorage object + if self.strategy == "policy_entropy": + score_function = self._average_entropy + elif self.strategy == "least_confidence": + score_function = self._average_least_confidence + elif self.strategy == "min_margin": + score_function = self._average_min_margin + elif self.strategy == "gae": + score_function = self._average_gae + elif self.strategy == "value_l1": + score_function = self._average_value_l1 + elif self.strategy == "one_step_td_error": + score_function = self._one_step_td_error + else: + raise ValueError(f"Unsupported strategy, {self.strategy}") + + self._update_with_rollouts(rollouts, score_function, actor_index=actor_id) + + def update_task_score(self, actor_index, task_idx, score, num_steps): + score = self._partial_update_task_score(actor_index, task_idx, score, num_steps, done=True) + + self.unseen_task_weights[task_idx] = 0.0 # No longer unseen + + old_score = self.task_scores[task_idx] + self.task_scores[task_idx] = (1.0 - self.alpha) * old_score + self.alpha * score + + def _partial_update_task_score(self, actor_index, task_idx, score, num_steps, done=False): + partial_score = self.partial_task_scores[actor_index][task_idx] + partial_num_steps = self.partial_task_steps[actor_index][task_idx] + + running_num_steps = partial_num_steps + num_steps + merged_score = partial_score + (score - partial_score) * num_steps / float(running_num_steps) + if done: + self.partial_task_scores[actor_index][task_idx] = 0.0 # zero partial score, partial num_steps + self.partial_task_steps[actor_index][task_idx] = 0 + else: + self.partial_task_scores[actor_index][task_idx] = merged_score + self.partial_task_steps[actor_index][task_idx] = running_num_steps + + return merged_score + + def _average_entropy(self, **kwargs): + episode_logits = kwargs["episode_logits"] + num_actions = self.action_space.n + max_entropy = -(1.0 / num_actions) * np.log(1.0 / num_actions) * num_actions + + return (-torch.exp(episode_logits) * episode_logits).sum(-1).mean().item() / max_entropy + + def _average_least_confidence(self, **kwargs): + episode_logits = kwargs["episode_logits"] + return (1 - torch.exp(episode_logits.max(-1, keepdim=True)[0])).mean().item() + + def _average_min_margin(self, **kwargs): + episode_logits = kwargs["episode_logits"] + top2_confidence = torch.exp(episode_logits.topk(2, dim=-1)[0]) + return 1 - (top2_confidence[:, 0] - top2_confidence[:, 1]).mean().item() + + def _average_gae(self, **kwargs): + returns = kwargs["returns"] + value_preds = kwargs["value_preds"] + + advantages = returns - value_preds + + return advantages.mean().item() + + def _average_value_l1(self, **kwargs): + returns = kwargs["returns"] + value_preds = kwargs["value_preds"] + + advantages = returns - value_preds + + return advantages.abs().mean().item() + + def _one_step_td_error(self, **kwargs): + rewards = kwargs["rewards"] + value_preds = kwargs["value_preds"] + + max_t = len(rewards) + td_errors = (rewards[:-1] + value_preds[: max_t - 1] - value_preds[1:max_t]).abs() + assert not torch.isnan( + td_errors.abs().mean() + ), f"Got invalid values for 'rewards' or 'value_preds'. Check that reward length: {len(rewards)}" + return td_errors.abs().mean().item() + + @property + def requires_value_buffers(self): + return self.strategy in ["gae", "value_l1", "one_step_td_error"] + + def _update_with_rollouts(self, rollouts, score_function, actor_index=None): + tasks = rollouts.tasks + if not self.requires_value_buffers: + policy_logits = rollouts.action_log_dist + done = ~(rollouts.masks > 0) + total_steps, num_actors = rollouts.tasks.shape[:2] + + actors = [actor_index] if actor_index is not None else range(num_actors) + for actor_index in actors: + done_steps = done[:, actor_index].nonzero()[:total_steps, 0] + start_t = 0 + + for t in done_steps: + if not start_t < total_steps: + break + + if (t == 0): # if t is 0, then this done step caused a full update of previous last cycle + continue + + # If there is only 1 step, we can't calculate the one-step td error + if self.strategy == "one_step_td_error" and t - start_t <= 1: + continue + + task_idx_t = tasks[start_t, actor_index].item() + + # Store kwargs for score function + score_function_kwargs = {} + if self.requires_value_buffers: + score_function_kwargs["returns"] = rollouts.returns[start_t:t, actor_index] + score_function_kwargs["rewards"] = rollouts.rewards[start_t:t, actor_index] + score_function_kwargs["value_preds"] = rollouts.value_preds[start_t:t, actor_index] + else: + episode_logits = policy_logits[start_t:t, actor_index] + score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1) + score = score_function(**score_function_kwargs) + num_steps = len(rollouts.tasks[start_t:t, actor_index]) + # TODO: Check that task_idx_t is correct + self.update_task_score(actor_index, task_idx_t, score, num_steps) + + start_t = t.item() + if start_t < total_steps: + # If there is only 1 step, we can't calculate the one-step td error + if self.strategy == "one_step_td_error" and start_t == total_steps - 1: + continue + # TODO: Check this too + task_idx_t = tasks[start_t, actor_index].item() + + # Store kwargs for score function + score_function_kwargs = {} + if self.requires_value_buffers: + score_function_kwargs["returns"] = rollouts.returns[start_t:, actor_index] + score_function_kwargs["rewards"] = rollouts.rewards[start_t:, actor_index] + score_function_kwargs["value_preds"] = rollouts.value_preds[start_t:, actor_index] + else: + episode_logits = policy_logits[start_t:, actor_index] + score_function_kwargs["episode_logits"] = torch.log_softmax(episode_logits, -1) + + score = score_function(**score_function_kwargs) + self._last_score = score + num_steps = len(rollouts.tasks[start_t:, actor_index]) + self._partial_update_task_score(actor_index, task_idx_t, score, num_steps) + + def after_update(self): + # Reset partial updates, since weights have changed, and thus logits are now stale + for actor_index in range(self.partial_task_scores.shape[0]): + for task_idx in range(self.partial_task_scores.shape[1]): + if self.partial_task_scores[actor_index][task_idx] != 0: + self.update_task_score(actor_index, task_idx, 0, 0) + self.partial_task_scores.fill(0) + self.partial_task_steps.fill(0) + + def _update_staleness(self, selected_idx): + if self.staleness_coef > 0: + self.task_staleness = self.task_staleness + 1 + self.task_staleness[selected_idx] = 0 + + def _sample_replay_level(self): + sample_weights = self.sample_weights() + if np.isclose(np.sum(sample_weights), 0): + sample_weights = np.ones_like(sample_weights, dtype=float) / len(sample_weights) + + task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0] + task = self.tasks[task_idx] + + self._update_staleness(task_idx) + + return task + + def _sample_unseen_level(self): + sample_weights = self.unseen_task_weights / self.unseen_task_weights.sum() + task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0] + task = self.tasks[task_idx] + + self._update_staleness(task_idx) + + return task + + def sample(self, strategy=None): + if not strategy: + strategy = self.strategy + + if strategy == "random": + task_idx = np.random.choice(range((self.num_tasks))) + task = self.tasks[task_idx] + return task + + if strategy == "sequential": + task_idx = self.next_task_index + self.next_task_index = (self.next_task_index + 1) % self.num_tasks + task = self.tasks[task_idx] + return task + + num_unseen = (self.unseen_task_weights > 0).sum() + proportion_seen = (self.num_tasks - num_unseen) / self.num_tasks + + if self.replay_schedule == "fixed": + if proportion_seen >= self.rho: + # Sample replay level with fixed prob = 1 - nu OR if all levels seen + if np.random.rand() > self.nu or not proportion_seen < 1.0: + return self._sample_replay_level() + + # Otherwise, sample a new level + return self._sample_unseen_level() + + elif self.replay_schedule == "proportionate": + if proportion_seen >= self.rho and np.random.rand() < proportion_seen: + return self._sample_replay_level() + else: + return self._sample_unseen_level() + else: + raise NotImplementedError(f"Unsupported replay schedule: {self.replay_schedule}. Must be 'fixed' or 'proportionate'.") + + def sample_weights(self): + weights = self._score_transform(self.score_transform, self.temperature, self.task_scores) + weights = weights * (1 - self.unseen_task_weights) # zero out unseen levels + z = np.sum(weights) + if z > 0: + weights /= z + + staleness_weights = 0 + if self.staleness_coef > 0: + staleness_weights = self._score_transform( + self.staleness_transform, + self.staleness_temperature, + self.task_staleness, + ) + staleness_weights = staleness_weights * (1 - self.unseen_task_weights) + z = np.sum(staleness_weights) + if z > 0: + staleness_weights /= z + weights = (1 - self.staleness_coef) * weights + self.staleness_coef * staleness_weights + return weights + + def _score_transform(self, transform, temperature, scores): + if transform == "constant": + weights = np.ones_like(scores) + if transform == "max": + weights = np.zeros_like(scores) + scores = scores[:] + scores[self.unseen_task_weights > 0] = -float("inf") # only argmax over seen levels + argmax = np.random.choice(np.flatnonzero(np.isclose(scores, scores.max()))) + weights[argmax] = 1.0 + elif transform == "eps_greedy": + weights = np.zeros_like(scores) + weights[scores.argmax()] = 1.0 - self.eps + weights += self.eps / self.num_tasks + elif transform == "rank": + temp = np.flip(scores.argsort()) + ranks = np.empty_like(temp) + ranks[temp] = np.arange(len(temp)) + 1 + weights = 1 / ranks ** (1.0 / temperature) + elif transform == "power": + eps = 0 if self.staleness_coef > 0 else 1e-3 + weights = (np.array(scores) + eps) ** (1.0 / temperature) + elif transform == "softmax": + weights = np.exp(np.array(scores) / temperature) + + return weights + + def metrics(self): + return { + "task_scores": self.task_scores, + "unseen_task_weights": self.unseen_task_weights, + "task_staleness": self.task_staleness, + "proportion_seen": (self.num_tasks - (self.unseen_task_weights > 0).sum()) / self.num_tasks, + "score": self._last_score, + } diff --git a/curricula/sequential.py b/curricula/sequential.py new file mode 100644 index 00000000..ec3b8b09 --- /dev/null +++ b/curricula/sequential.py @@ -0,0 +1,207 @@ +import re +import warnings +from typing import Any, Callable, List, Union + +from syllabus.core import Curriculum +from syllabus.curricula import NoopCurriculum, DomainRandomization +from syllabus.task_space import TaskSpace + + +class SequentialCurriculum(Curriculum): + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = True + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, curriculum_list: List[Curriculum], stopping_conditions: List[Any], *curriculum_args, **curriculum_kwargs): + super().__init__(*curriculum_args, **curriculum_kwargs) + assert len(curriculum_list) > 0, "Must provide at least one curriculum" + assert len(stopping_conditions) == len(curriculum_list) - 1, f"Stopping conditions must be one less than the number of curricula. Final curriculum is used for the remainder of training. Expected {len(curriculum_list) - 1}, got {len(stopping_conditions)}." + if len(curriculum_list) == 1: + warnings.warn("Your sequential curriculum only containes one element. Consider using that element directly instead.") + + self.curriculum_list = self._parse_curriculum_list(curriculum_list) + self.stopping_conditions = self._parse_stopping_conditions(stopping_conditions) + self._curriculum_index = 0 + + # Stopping metrics + self.n_steps = 0 + self.total_steps = 0 + self.n_episodes = 0 + self.total_episodes = 0 + self.n_tasks = 0 + self.total_tasks = 0 + self.episode_returns = [] + + def _parse_curriculum_list(self, curriculum_list: List[Curriculum]) -> List[Curriculum]: + """ Parse the curriculum list to ensure that all items are curricula. + Adds Curriculum objects directly. Wraps task space items in NoopCurriculum objects. + """ + parsed_list = [] + for item in curriculum_list: + if isinstance(item, Curriculum): + parsed_list.append(item) + elif isinstance(item, TaskSpace): + parsed_list.append(DomainRandomization(item)) + elif isinstance(item, list): + task_space = TaskSpace(len(item), item) + parsed_list.append(DomainRandomization(task_space)) + elif self.task_space.contains(item): + parsed_list.append(NoopCurriculum(item, self.task_space)) + else: + raise ValueError(f"Invalid curriculum item: {item}") + + return parsed_list + + def _parse_stopping_conditions(self, stopping_conditions: List[Any]) -> List[Any]: + """ Parse the stopping conditions to ensure that all items are integers. """ + parsed_list = [] + for item in stopping_conditions: + if isinstance(item, Callable): + parsed_list.append(item) + elif isinstance(item, str): + parsed_list.append(self._parse_condition_string(item)) + else: + raise ValueError(f"Invalid stopping condition: {item}") + + return parsed_list + + def _parse_condition_string(self, condition: str) -> Callable: + """ Parse a string condition to a callable function. """ + + # Parse composite conditions + if '|' in condition: + conditions = re.split(re.escape('|'), condition) + return lambda: any(self._parse_condition_string(cond)() for cond in conditions) + elif '&' in condition: + conditions = re.split(re.escape('&'), condition) + return lambda: all(self._parse_condition_string(cond)() for cond in conditions) + + clauses = re.split('(<=|>=|=|<|>)', condition) + + try: + metric, comparator, value = clauses + + if metric == "steps": + metric_fn = self._get_steps + elif metric == "total_steps": + metric_fn = self._get_total_steps + elif metric == "episodes": + metric_fn = self._get_episodes + elif metric == "total_episodes": + metric_fn = self._get_total_episodes + elif metric == "tasks": + metric_fn = self._get_tasks + elif metric == "total_tasks": + metric_fn = self._get_total_tasks + elif metric == "episode_return": + metric_fn = self._get_episode_return + else: + raise ValueError(f"Invalid metric name: {metric}") + + if comparator == '<': + return lambda: metric_fn() < float(value) + elif comparator == '>': + return lambda: metric_fn() > float(value) + elif comparator == '<=': + return lambda: metric_fn() <= float(value) + elif comparator == '>=': + return lambda: metric_fn() >= float(value) + elif comparator == '=': + return lambda: metric_fn() == float(value) + else: + raise ValueError(f"Invalid comparator: {comparator}") + except ValueError as e: + raise ValueError(f"Invalid condition string: {condition}") from e + + def _get_steps(self): + return self.n_steps + + def _get_total_steps(self): + return self.total_steps + + def _get_episodes(self): + return self.n_episodes + + def _get_total_episodes(self): + return self.total_episodes + + def _get_tasks(self): + return self.n_tasks + + def _get_total_tasks(self): + return self.total_tasks + + def _get_episode_return(self): + return sum(self.episode_returns) / len(self.episode_returns) if len(self.episode_returns) > 0 else 0 + + @property + def current_curriculum(self): + return self.curriculum_list[self._curriculum_index] + + @property + def requires_step_updates(self): + return any(map(lambda c: c.requires_step_updates, self.curriculum_list)) + + def _sample_distribution(self) -> List[float]: + """ + Return None to indicate that tasks are not drawn from a distribution. + """ + return None + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Choose the next k tasks from the list. + """ + curriculum = self.current_curriculum + tasks = curriculum.sample(k) + + # Recode tasks into environment task space + decoded_tasks = [curriculum.task_space.decode(task) for task in tasks] + recoded_tasks = [self.task_space.encode(task) for task in decoded_tasks] + + self.n_tasks += k + self.total_tasks += k + + # Check if we should move on to the next phase of the curriculum + self.check_stopping_conditions() + return recoded_tasks + + def update_on_episode(self, episode_return, episode_len, episode_task, env_id=None): + self.n_episodes += 1 + self.total_episodes += 1 + self.n_steps += episode_len + self.total_steps += episode_len + self.episode_returns.append(episode_return) + + # Update current curriculum + if self.current_curriculum.requires_episode_updates: + self.current_curriculum.update_on_episode(episode_return, episode_len, episode_task, env_id) + + def update_on_step(self, task, obs, rew, term, trunc, info, env_id=None): + if self.current_curriculum.requires_step_updates: + self.current_curriculum.update_on_step(task, obs, rew, term, trunc, info, env_id) + + def update_on_step_batch(self, step_results, env_id=None): + if self.current_curriculum.requires_step_updates: + self.current_curriculum.update_on_step_batch(step_results, env_id) + + def update_on_demand(self, metrics): + self.current_curriculum.update_on_demand(metrics) + + def update_task_progress(self, task, progress, env_id=None): + self.current_curriculum.update_task_progress(task, progress, env_id) + + def check_stopping_conditions(self): + if self._curriculum_index < len(self.stopping_conditions) and self.stopping_conditions[self._curriculum_index](): + self._curriculum_index += 1 + self.n_episodes = 0 + self.n_steps = 0 + self.episode_returns = [] + self.n_tasks = 0 + + def log_metrics(self, writer, step=None, log_full_dist=False): + # super().log_metrics(writer, step, log_full_dist) + writer.add_scalar("curriculum/current_stage", self._curriculum_index, step) + writer.add_scalar("curriculum/steps", self.n_steps, step) + writer.add_scalar("curriculum/episodes", self.n_episodes, step) + writer.add_scalar("curriculum/episode_returns", self._get_episode_return(), step) diff --git a/curricula/simple_box.py b/curricula/simple_box.py new file mode 100644 index 00000000..97ed6e7e --- /dev/null +++ b/curricula/simple_box.py @@ -0,0 +1,63 @@ +import typing +from typing import Any, List, Union + +from gymnasium.spaces import Box + +from syllabus.core import Curriculum + + +class SimpleBoxCurriculum(Curriculum): + """ + Base class and API for defining curricula to interface with Gym environments. + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, + *curriculum_args, + steps: int = 5, + success_threshold: float = 0.25, + required_successes: int = 10, + **curriculum_kwargs): + super().__init__(*curriculum_args, **curriculum_kwargs) + assert isinstance(self.task_space.gym_space, Box), "SimpleBoxCurriculum only supports Box task spaces." + + self.success_threshold = success_threshold + self.required_successes = required_successes + + full_range = self.task_space.gym_space.high[1] - self.task_space.gym_space.low[0] + midpoint = self.task_space.gym_space.low[0] + (full_range / 2.0) + self.step_size = (full_range / 2.0) / steps + self.max_range = (midpoint - self.step_size, midpoint + self.step_size) + self.consecutive_successes = 0 + self.max_reached = False + + def update_task_progress(self, task: typing.Any, success_prob: float, env_id: int = None) -> None: + """ + Update the curriculum with a task and its success probability upon + success or failure. + """ + if self.max_reached: + return + + # Check if this task passed success threshold + if success_prob > self.success_threshold: + self.consecutive_successes += 1 + else: + self.consecutive_successes = 0 + + # If we have enough successes in a row, update task + if self.consecutive_successes >= self.required_successes: + new_low = max(self.max_range[0] - self.step_size, self.task_space.gym_space.low[0]) + new_high = min(self.max_range[1] + self.step_size, self.task_space.gym_space.high[1]) + self.max_range = (new_low, new_high) + self.consecutive_successes = 0 + if new_low == self.task_space.gym_space.low[0] and new_high == self.task_space.gym_space.high[1]: + self.max_reached = True + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + return [self.max_range for _ in range(k)] diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/experimental/README b/examples/experimental/README new file mode 100644 index 00000000..1302681f --- /dev/null +++ b/examples/experimental/README @@ -0,0 +1 @@ +The training scripts in this folder are unlikely to run without errors, but may serve as useful references if you want to use Syllabus in a new RL library or environment. \ No newline at end of file diff --git a/examples/experimental/cleanrl_cartpole.py b/examples/experimental/cleanrl_cartpole.py new file mode 100644 index 00000000..ab71e478 --- /dev/null +++ b/examples/experimental/cleanrl_cartpole.py @@ -0,0 +1,326 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +# Syllabus imports +from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum +from syllabus.curricula import SimpleBoxCurriculum +from syllabus.examples.task_wrappers import CartPoleTaskWrapper +from syllabus.task_space import TaskSpace +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="Syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="CartPole-v1", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=500000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=2.5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=4, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=128, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=4, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name, task_queue, update_queue): + def thunk(): + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + # env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + env = CartPoleTaskWrapper(env) + env = MultiProcessingSyncWrapper(env, task_queue, update_queue, + task_space=TaskSpace(gym.spaces.Box(-0.3, 0.3, shape=(2,))), + update_on_step=False) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.critic = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 1), std=1.0), + ) + self.actor = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01), + ) + + def get_value(self, x): + return self.critic(x) + + def get_action_and_value(self, x, action=None): + logits = self.actor(x) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(x) + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # curriculum setup + curriculum = SimpleBoxCurriculum(TaskSpace(gym.spaces.Box(-0.3, 0.3, shape=(2,)))) + curriculum, task_queue, update_queue = make_multiprocessing_curriculum(curriculum) + + # env setup + envs = gym.vector.AsyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, task_queue, update_queue) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset() + next_obs = torch.Tensor(next_obs).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, term, trunc, infos = envs.step(action.cpu().numpy()) + done = np.logical_or(term, trunc) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + + if "final_info" in infos: + for info in infos["final_info"]: + if info and "episode" in info: + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + curriculum.log_metrics(writer, step=global_step) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + envs.close() + writer.close() diff --git a/examples/experimental/cleanrl_minigrid_plr.py b/examples/experimental/cleanrl_minigrid_plr.py new file mode 100644 index 00000000..83454aa9 --- /dev/null +++ b/examples/experimental/cleanrl_minigrid_plr.py @@ -0,0 +1,353 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from syllabus.core import (MultiProcessingSyncWrapper, TaskWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples.models import MinigridAgent +from syllabus.examples.task_wrappers import MinigridTaskWrapper +from torch.utils.tensorboard import SummaryWriter + +from .vecenv import VecNormalize + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="MiniGrid-MultiRoom-N4-Random-v0", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=100000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=7e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=256, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gamma", type=float, default=0.999, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=8, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name, task_queue, update_queue): + def thunk(): + env = gym.make(env_id) + #env = gym.wrappers.FlattenObservation(env) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = MinigridTaskWrapper(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + # env = MultiProcessingSyncWrapper( + # env, + # task_queue, + # update_queue, + # update_on_step=False, + # default_task=0, + # task_space=env.task_space, + # ) + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # sample_env = gym.make(args.env_id) + # sample_env = MinigridTaskWrapper(sample_env) + # curriculum = PrioritizedLevelReplay( + # sample_env.task_space, task_sampler_kwargs_dict={"strategy":"one_step_td_error", "rho":0.01, "nu":0}, num_processes=args.num_envs, gamma=args.gamma, gae_lambda=args.gae_lambda, + # ) + # curriculum, task_queue, update_queue = make_multiprocessing_curriculum(curriculum) + # del sample_env + task_queue = update_queue = None + # env setup + envs = gym.vector.AsyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, task_queue, update_queue) + for i in range(args.num_envs) + ] + ) + envs = VecNormalize(envs, ob=False) + + assert isinstance( + envs.single_action_space, gym.spaces.Discrete + ), "only discrete action space is supported" + + agent = MinigridAgent(envs.single_observation_space.shape, envs.single_action_space.n).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros( + (args.num_steps, args.num_envs) + envs.single_observation_space.shape + ).to(device) + actions = torch.zeros( + (args.num_steps, args.num_envs) + envs.single_action_space.shape + ).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value, full_log_probs = agent.get_action_and_value( + next_obs, full_log_probs=True + ) + values[step] = value.flatten() + + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor( + done + ).to(device) + + for item in info: + if "episode" in item.keys(): + print( + f"global_step={global_step}, episodic_return={item['episode']['r']}" + ) + writer.add_scalar( + "charts/episodic_return", item["episode"]["r"], global_step + ) + writer.add_scalar( + "charts/episodic_length", item["episode"]["l"], global_step + ) + break + + # update = { + # "update_type": "on_demand", + # "metrics": { + # "action_log_dist": full_log_probs, + # "value": value, + # "next_value": agent.get_value(next_obs) + # if step == args.num_steps - 1 + # else None, + # "rew": reward, + # "masks": torch.Tensor(1 - done), + # "tasks": envs.get_attr("task"), + # }, + # } + # curriculum.update_curriculum(update) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = ( + rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + ) + advantages[t] = lastgaelam = ( + delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + ) + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value( + b_obs[mb_inds], b_actions.long()[mb_inds] + ) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [ + ((ratio - 1.0).abs() > args.clip_coef).float().mean().item() + ] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / ( + mb_advantages.std() + 1e-8 + ) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp( + ratio, 1 - args.clip_coef, 1 + args.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar( + "charts/learning_rate", optimizer.param_groups[0]["lr"], global_step + ) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar( + "charts/SPS", int(global_step / (time.time() - start_time)), global_step + ) + + envs.close() + writer.close() diff --git a/examples/experimental/cleanrl_minihack_plr.py b/examples/experimental/cleanrl_minihack_plr.py new file mode 100644 index 00000000..8f88ad5d --- /dev/null +++ b/examples/experimental/cleanrl_minihack_plr.py @@ -0,0 +1,358 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gym +#import minigrid +import minihack +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from gym.envs.registration import register +from syllabus.core import (MultiProcessingSyncWrapper, TaskWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples import MinihackTaskWrapper +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="MiniHack-River-v0", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=500000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=2.5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=4, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=128, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=4, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name, task_queue, update_queue): + def thunk(): + env = gym.make(env_id, observation_keys=("pixel", "glyphs")) + env = gym.wrappers.FlattenObservation(env) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = MinihackTaskWrapper(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env = MultiProcessingSyncWrapper(env, + task_queue, + update_queue, + update_on_step=False, + default_task=0, + task_space=gym.spaces.Discrete(1)) + #env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.critic = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 1), std=1.0), + ) + self.actor = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01), + ) + + def get_value(self, x): + return self.critic(x) + + def get_action_and_value(self, x, action=None, full_log_probs=False): + logits = self.actor(x) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + if full_log_probs: + log_probs = torch.log(probs.probs) + return action, probs.log_prob(action), probs.entropy(), self.critic(x), log_probs + return action, probs.log_prob(action), probs.entropy(), self.critic(x) + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + sample_env = gym.make(args.env_id) + sample_env = MinihackTaskWrapper(sample_env) + + curriculum, task_queue, update_queue = make_multiprocessing_curriculum(CentralizedPrioritizedLevelReplay, + sample_env.task_space, + {"strategy": "one_step_td_error", + "rho": 0.01, + "nu": 0}, + action_space=sample_env.action_space, + num_steps=args.num_steps, + num_processes=args.num_envs, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + random_start_tasks=0) + del sample_env + + # env setup + envs = gym.vector.AsyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, task_queue, update_queue) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value, full_log_probs = agent.get_action_and_value(next_obs, full_log_probs=True) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + + for item in info: + if "episode" in item.keys(): + print(f"global_step={global_step}, episodic_return={item['episode']['r']}") + writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) + break + + update = { + "update_type": "on_demand", + "metrics": { + "action_log_dist": full_log_probs, + "value": value, + "next_value": agent.get_value(next_obs) if step == args.num_steps - 1 else None, + "rew": reward, + "masks": torch.Tensor(1 - done), + "tasks": envs.get_attr("current_task"), + } + } + curriculum.update_curriculum(update) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() \ No newline at end of file diff --git a/examples/experimental/cleanrl_pettingzoo_pistonball_plr.py b/examples/experimental/cleanrl_pettingzoo_pistonball_plr.py new file mode 100644 index 00000000..2b416e21 --- /dev/null +++ b/examples/experimental/cleanrl_pettingzoo_pistonball_plr.py @@ -0,0 +1,334 @@ +"""Basic code which shows what it's like to run PPO on the Pistonball env using the parallel API, this code is inspired by CleanRL. + +This code is exceedingly basic, with no logging or weights saving. +The intention was for users to have a (relatively clean) ~200 line file to refer to when they want to design their own learning algorithm. + +Author: Jet (https://github.com/jjshoots) +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from pettingzoo.butterfly import pistonball_v6 +from supersuit import color_reduction_v0, frame_stack_v1, resize_v1 +from syllabus.core import (PettingZooMultiProcessingSyncWrapper, TaskWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples import PistonballTaskWrapper +from torch.distributions.categorical import Categorical + + +class Agent(nn.Module): + def __init__(self, num_actions): + super().__init__() + + self.network = nn.Sequential( + self._layer_init(nn.Conv2d(4, 32, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + self._layer_init(nn.Conv2d(32, 64, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + self._layer_init(nn.Conv2d(64, 128, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + nn.Flatten(), + self._layer_init(nn.Linear(128 * 8 * 8, 512)), + nn.ReLU(), + ) + self.actor = self._layer_init(nn.Linear(512, num_actions), std=0.01) + self.critic = self._layer_init(nn.Linear(512, 1)) + + def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None, full_log_probs=False): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + if full_log_probs: + log_probs = torch.log(probs.probs) + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden), log_probs + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +def batchify_obs(obs, device): + """Converts PZ style observations to batch of torch arrays.""" + # convert to list of np arrays + obs = np.stack([obs[a] for a in obs], axis=0) + # transpose to be (batch, channel, height, width) + obs = obs.transpose(0, -1, 1, 2) + # convert to torch + obs = torch.tensor(obs).to(device) + + return obs + + +def batchify(x, device): + """Converts PZ style returns to batch of torch arrays.""" + # convert to list of np arrays + x = np.stack([x[a] for a in x], axis=0) + # convert to torch + x = torch.tensor(x).to(device) + + return x + + +def unbatchify(x, env): + """Converts np array to PZ style arguments.""" + x = x.cpu().numpy() + x = {a: x[i] for i, a in enumerate(env.possible_agents)} + + return x + + +if __name__ == "__main__": + """ALGO PARAMS""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ent_coef = 0.1 + vf_coef = 0.1 + clip_coef = 0.1 + gamma = 0.99 + batch_size = 32 + stack_size = 4 + frame_size = (64, 64) + max_cycles = 125 + total_episodes = 100 + + # PLR settings + num_steps = 128 + + """ CURRICULUM SETUP """ + sample_env = pistonball_v6.parallel_env( + continuous=False, max_cycles=max_cycles + ) + sample_env = PistonballTaskWrapper(sample_env) + + num_agents = len(sample_env.possible_agents) + action_space = sample_env.action_spaces[sample_env.possible_agents[0]] + num_actions = action_space.n + observation_size = sample_env.observation_space(sample_env.possible_agents[0]).shape + + curriculum, task_queue, update_queue = make_multiprocessing_curriculum(CentralizedPrioritizedLevelReplay, + sample_env.task_space, + task_sampler_kwargs_dict={"strategy": "one_step_td_error", + "rho": 0.01, + "nu": 0}, + action_space=action_space, + num_steps=num_steps, + num_processes=num_agents, + gamma=gamma, + gae_lambda=0.95, + random_start_tasks=0) + del sample_env + + """ ENV SETUP """ + env = pistonball_v6.parallel_env( + continuous=False, max_cycles=max_cycles + ) + env = PistonballTaskWrapper(env) + env = PettingZooMultiProcessingSyncWrapper(env, + task_queue, + update_queue, + update_on_step=False, + default_task=0, + task_space=env.task_space) + task_space = env.task_space + env = color_reduction_v0(env) + env = resize_v1(env, frame_size[0], frame_size[1]) + env = frame_stack_v1(env, stack_size=stack_size) + num_agents = len(env.possible_agents) + action_space = env.action_spaces[env.possible_agents[0]] + num_actions = action_space.n + observation_size = env.observation_space(env.possible_agents[0]).shape + + """ LEARNER SETUP """ + agent = Agent(num_actions=num_actions).to(device) + optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5) + + """ ALGO LOGIC: EPISODE STORAGE""" + end_step = 0 + total_episodic_return = 0 + rb_obs = torch.zeros((max_cycles, num_agents, stack_size, *frame_size)).to(device) + rb_actions = torch.zeros((max_cycles, num_agents)).to(device) + rb_logprobs = torch.zeros((max_cycles, num_agents)).to(device) + rb_rewards = torch.zeros((max_cycles, num_agents)).to(device) + rb_dones = torch.zeros((max_cycles, num_agents)).to(device) + rb_values = torch.zeros((max_cycles, num_agents)).to(device) + + """ TRAINING LOGIC """ + # train for n number of episodes + global_cycles = 0 + for episode in range(total_episodes): + # collect an episode + with torch.no_grad(): + # collect observations and convert to batch of torch tensors + next_obs = env.reset() + # reset the episodic return + total_episodic_return = 0 + + # each episode has num_steps + for step in range(0, max_cycles): + global_cycles += 1 + # rollover the observation + obs = batchify_obs(next_obs, device) + + # get action from the agent + actions, logprobs, _, values, full_log_probs = agent.get_action_and_value(obs, full_log_probs=True) + + # execute the environment and log data + next_obs, rewards, dones, infos = env.step( + unbatchify(actions, env) + ) + + # add to episode storage + rb_obs[step] = obs + rb_rewards[step] = batchify(rewards, device) + rb_dones[step] = batchify(dones, device) + rb_actions[step] = actions + rb_logprobs[step] = logprobs + rb_values[step] = values.flatten() + + # compute episodic return + total_episodic_return += rb_rewards[step].cpu().numpy() + + # Update curriculum + if global_cycles % num_steps == 0: + update = { + "update_type": "on_demand", + "metrics": { + "action_log_dist": full_log_probs, + "value": values, + "next_value": agent.get_value(next_obs) if step == num_steps - 1 else None, + "rew": rb_rewards[step], + "masks": torch.Tensor(1 - np.array(list(dones.values()))), + "tasks": [env.unwrapped.task], + } + } + curriculum.update_curriculum(update) + + # if we reach the end of the episode + if any([dones[a] for a in dones]): + end_step = step + break + + + # bootstrap value if not done + with torch.no_grad(): + rb_advantages = torch.zeros_like(rb_rewards).to(device) + for t in reversed(range(end_step)): + delta = ( + rb_rewards[t] + + gamma * rb_values[t + 1] * rb_dones[t + 1] + - rb_values[t] + ) + rb_advantages[t] = delta + gamma * gamma * rb_advantages[t + 1] + rb_returns = rb_advantages + rb_values + + # convert our episodes to batch of individual transitions + b_obs = torch.flatten(rb_obs[:end_step], start_dim=0, end_dim=1) + b_logprobs = torch.flatten(rb_logprobs[:end_step], start_dim=0, end_dim=1) + b_actions = torch.flatten(rb_actions[:end_step], start_dim=0, end_dim=1) + b_returns = torch.flatten(rb_returns[:end_step], start_dim=0, end_dim=1) + b_values = torch.flatten(rb_values[:end_step], start_dim=0, end_dim=1) + b_advantages = torch.flatten(rb_advantages[:end_step], start_dim=0, end_dim=1) + + # Optimizing the policy and value network + b_index = np.arange(len(b_obs)) + clip_fracs = [] + for repeat in range(3): + # shuffle the indices we use to access the data + np.random.shuffle(b_index) + for start in range(0, len(b_obs), batch_size): + # select the indices we want to train on + end = start + batch_size + batch_index = b_index[start:end] + + _, newlogprob, entropy, value = agent.get_action_and_value( + b_obs[batch_index], b_actions.long()[batch_index] + ) + logratio = newlogprob - b_logprobs[batch_index] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clip_fracs += [ + ((ratio - 1.0).abs() > clip_coef).float().mean().item() + ] + + # normalize advantaegs + advantages = b_advantages[batch_index] + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) + + # Policy loss + pg_loss1 = -b_advantages[batch_index] * ratio + pg_loss2 = -b_advantages[batch_index] * torch.clamp( + ratio, 1 - clip_coef, 1 + clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + value = value.flatten() + v_loss_unclipped = (value - b_returns[batch_index]) ** 2 + v_clipped = b_values[batch_index] + torch.clamp( + value - b_values[batch_index], + -clip_coef, + clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + + entropy_loss = entropy.mean() + loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + print(f"Training episode {episode}") + print(f"Episodic Return: {np.mean(total_episodic_return)}") + print(f"Episode Length: {end_step}") + print("") + print(f"Value Loss: {v_loss.item()}") + print(f"Policy Loss: {pg_loss.item()}") + print(f"Old Approx KL: {old_approx_kl.item()}") + print(f"Approx KL: {approx_kl.item()}") + print(f"Clip Fraction: {np.mean(clip_fracs)}") + print(f"Explained Variance: {explained_var.item()}") + print("\n-------------------------------------------\n") + + """ RENDER THE POLICY """ + env = pistonball_v6.parallel_env(continuous=False) + env = color_reduction_v0(env) + env = resize_v0(env, 64, 64) + env = frame_stack_v1(env, stack_size=4) + + agent.eval() + + with torch.no_grad(): + # render 5 episodes out + for episode in range(5): + obs = batchify_obs(env.reset(), device) + dones = [False] + while not any(dones): + actions, logprobs, _, values = agent.get_action_and_value(obs) + obs, rewards, dones, infos = env.step(unbatchify(actions, env)) + obs = batchify_obs(obs, device) + dones = [dones[a] for a in dones] diff --git a/examples/experimental/dormant_neurons.py b/examples/experimental/dormant_neurons.py new file mode 100644 index 00000000..ca2eb21c --- /dev/null +++ b/examples/experimental/dormant_neurons.py @@ -0,0 +1,764 @@ +"""This class implements recycling of dead neurons.""" +import functools +import logging + +import flax +import jax +import jax.numpy as jnp +import optax +from flax import linen as flax_nn +from jax import random + + +def leastk_mask(scores, ones_fraction): + """Given a tensor of scores creates a binary mask. + + Args: + scores: top-scores are kept + ones_fraction: float, of the generated mask. + + Returns: + array, same shape and type as scores or None. + """ + if ones_fraction is None or ones_fraction == 0: + return jnp.zeros_like(scores) + # This is to ensure indices with smallest values are selected. + scores = -scores + + n_ones = jnp.round(jnp.size(scores) * ones_fraction) + k = jnp.maximum(1, n_ones).astype(int) + flat_scores = jnp.reshape(scores, -1) + threshold = jax.lax.sort(flat_scores)[-k] + + mask = (flat_scores >= threshold).astype(flat_scores.dtype) + return mask.reshape(scores.shape) + + +def reset_momentum(momentum, mask): + new_momentum = momentum if mask is None else momentum * (1.0 - mask) + return new_momentum + + +def weight_reinit_zero(param, mask): + if mask is None: + return param + else: + new_param = jnp.zeros_like(param) + param = jnp.where(mask == 1, new_param, param) + return param + + +def weight_reinit_random(param, mask, key, weight_scaling=False, scale=1.0, weights_type='incoming'): + """Randomly reinit recycled weights and may scale its norm. + + If scaling applied, the norm of recycled weights equals + the average norm of non recycled weights per neuron multiplied by a scalar. + + Args: + param: current param + mask: incoming/outgoing mask for recycled weights + key: random key to generate new random weights + weight_scaling: if true scale recycled weights with the norm of non recycled + scale: scale to multiply the new weights norm. + weights_type: incoming or outgoing weights + + Returns: + params: new params after weight recycle. + """ + if mask is None or key is None: + return param + + new_param = flax_nn.initializers.xavier_uniform()(key, shape=param.shape) + + if weight_scaling: + axes = list(range(param.ndim)) + if weights_type == 'outgoing': + del axes[-2] + else: + del axes[-1] + + neuron_mask = jnp.mean(mask, axis=axes) + + non_dead_count = neuron_mask.shape[0] - jnp.count_nonzero(neuron_mask) + norm_per_neuron = _get_norm_per_neuron(param, axes) + non_recycled_norm = jnp.sum(norm_per_neuron * + (1 - neuron_mask)) / non_dead_count + non_recycled_norm = non_recycled_norm * scale + + normalized_new_param = _weight_normalization_per_neuron_norm( + new_param, axes) + new_param = normalized_new_param * non_recycled_norm + + param = jnp.where(mask == 1, new_param, param) + return param + + +def _weight_normalization_per_neuron_norm(param, axes): + norm_per_neuron = _get_norm_per_neuron(param, axes) + norm_per_neuron = jnp.expand_dims(norm_per_neuron, axis=axes) + normalized_param = param / norm_per_neuron + return normalized_param + + +def _get_norm_per_neuron(param, axes): + return jnp.sqrt(jnp.sum(jnp.power(param, 2), axis=axes)) + + +class BaseRecycler(): + """Base class for weight update methods. + + Attributes: + all_layers_names: list of layer names in a model. + recycle_type: neuron, layer based. + dead_neurons_threshold: below this threshold a neuron is considered dead. + reset_layers: list of layer names to be recycled. + reset_start_layer_idx: index of the layer from which we start recycling. + reset_period: int represents the period of weight update. + reset_start_step: start recycle from start step + reset_end_step: end recycle from end step + logging_period: the period of statistics logging e.g., dead neurons. + prev_neuron_score: score at last reset step or log step in case of no reset. + sub_mean_score: if True the average activation will be subtracted + for each neuron when we calculate the score. + """ + + def __init__( + self, + all_layers_names, + dead_neurons_threshold=0.0, + reset_start_layer_idx=0, + reset_period=200_000, + reset_start_step=0, + reset_end_step=100_000_000, + logging_period=20_000, + sub_mean_score=False, + ): + self.all_layers_names = all_layers_names + self.dead_neurons_threshold = dead_neurons_threshold + self.reset_layers = all_layers_names[reset_start_layer_idx:] + self.reset_period = reset_period + self.reset_start_step = reset_start_step + self.reset_end_step = reset_end_step + self.logging_period = logging_period + self.prev_neuron_score = None + self.sub_mean_score = sub_mean_score + + def update_reset_layers(self, reset_start_layer_idx): + self.reset_layers = self.all_layers_names[reset_start_layer_idx:] + + def is_update_iter(self, step): + return step > 0 and (step % self.reset_period == 0) + + def update_weights(self, intermediates, params, key, opt_state): + raise NotImplementedError + + def maybe_update_weights(self, update_step, intermediates, params, key, opt_state): + self._last_update_step = update_step + if self.is_reset(update_step): + new_params, new_opt_state = self.update_weights(intermediates, params, key, opt_state) + else: + new_params, new_opt_state = params, opt_state + return new_params, new_opt_state + + def is_reset(self, update_step): + del update_step + return False + + def is_intermediated_required(self, update_step): + return self.is_logging_step(update_step) + + def is_logging_step(self, step): + return step % self.logging_period == 0 + + def maybe_log_deadneurons(self, update_step, intermediates): + is_logging = self.is_logging_step(update_step) + if is_logging: + return self.log_dead_neurons_count(intermediates) + else: + return None + + def intersected_dead_neurons_with_last_reset(self, intermediates, update_step): + if self.is_logging_step(update_step): + log_dict = self.log_intersected_dead_neurons(intermediates) + return log_dict + else: + return None + + def log_intersected_dead_neurons(self, intermediates): + """Track intersected dead neurons with last logging/reset step. + + Args: + intermediates: current intermediates + + Returns: + log_dict: dict contains the percentage of intersection + """ + score_tree = jax.tree_map(self.estimate_neuron_score, intermediates) + neuron_score_dict = flax.traverse_util.flatten_dict(score_tree, sep='/') + + if self.prev_neuron_score is None: + self.prev_neuron_score = neuron_score_dict + log_dict = None + else: + log_dict = {} + for prev_k_score, current_k_score in zip(self.prev_neuron_score.items(), + neuron_score_dict.items()): + _, prev_score = prev_k_score + k, score = current_k_score + prev_score, score = prev_score[0], score[0] + prev_mask = prev_score <= self.dead_neurons_threshold + # we count the dead neurons which remains dead in the current step. + intersected_mask = (prev_mask) & (score <= self.dead_neurons_threshold) + prev_dead_count = jnp.count_nonzero(prev_mask) + intersected_count = jnp.count_nonzero(intersected_mask) + + percent = (float(intersected_count) / + prev_dead_count) if prev_dead_count else 0.0 + log_dict[f'dead_intersected_percent/{k[:-9]}'] = float(percent) * 100. + + # track average score of recycled neurons from last step and + # the average of non dead in the current step. + nondead_mask = score > self.dead_neurons_threshold + log_dict[f'mean_score_recycled/{k[:-9]}'] = float( + jnp.mean(score[prev_mask])) + log_dict[f'mean_score_nondead/{k[:-9]}'] = float( + jnp.mean(score[nondead_mask])) + + self.prev_neuron_score = neuron_score_dict + return log_dict + + def log_dead_neurons_count(self, intermediates): + """log dead neurons in each layer. + + For conv layer we also log dead elements in the spatial dimension. + + Args: + intermediates: intermidate activation in each layer. + + Returns: + log_dict_elements_per_neuron + log_dict_neurons + """ + + def log_dict(score, score_type): + total_neurons, total_deadneurons = 0., 0. + score_dict = flax.traverse_util.flatten_dict(score, sep='/') + + log_dict = {} + for k, m in score_dict.items(): + if 'final_layer' in k: + continue + m = m[0] + layer_size = float(jnp.size(m)) + deadneurons_count = jnp.count_nonzero(m <= self.dead_neurons_threshold) + total_neurons += layer_size + total_deadneurons += deadneurons_count + log_dict[f'dead_{score_type}_percentage/{k[:-9]}'] = ( + float(deadneurons_count) / layer_size) * 100. + log_dict[f'dead_{score_type}_count/{k[:-9]}'] = float(deadneurons_count) + log_dict[f'{score_type}/total'] = total_neurons + log_dict[f'{score_type}/deadcount'] = float(total_deadneurons) + log_dict[f'dead_{score_type}_percentage'] = (float(total_deadneurons) / + total_neurons) * 100. + return log_dict + + neuron_score = jax.tree_map(self.estimate_neuron_score, intermediates) + log_dict_neurons = log_dict(neuron_score, 'feature') + + return log_dict_neurons + + def estimate_neuron_score(self, activation, is_cbp=False): + """Calculates neuron score based on absolute value of activation. + + The score of feature map is the normalized average score over + the spatial dimension. + + Args: + activation: intermediate activation of each layer + is_cbp: if true, subtracts the mean and skips normalization. + + Returns: + element_score: score of each element in feature map in the spatial dim. + neuron_score: score of feature map + """ + reduce_axes = list(range(activation.ndim - 1)) + if self.sub_mean_score or is_cbp: + activation = activation - jnp.mean(activation, axis=reduce_axes) + + score = jnp.mean(jnp.abs(activation), axis=reduce_axes) + if not is_cbp: + # Normalize so that all scores sum to one. + score /= jnp.mean(score) + 1e-9 + + return score + + +class NeuronRecycler(BaseRecycler): + """Recycle the weights connected to dead neurons. + + In convolutional neural networks, we consider a feature map as neuron. + + Attributes: + next_layers: dict key a current layer name, value next layer name. + init_method_outgoing: method to init outgoing weights (random, zero). + weight_scaling: if true, scale reinit weights. + incoming_scale: scalar for incoming weights. + outgoing_scale: scalar for outgoing weights. + """ + + def __init__(self, + all_layers_names, + init_method_outgoing='zero', + weight_scaling=False, + incoming_scale=1.0, + outgoing_scale=1.0, + network='nature', + **kwargs): + super(NeuronRecycler, self).__init__(all_layers_names, **kwargs) + self.init_method_outgoing = init_method_outgoing + self.weight_scaling = weight_scaling + self.incoming_scale = incoming_scale + self.outgoing_scale = outgoing_scale + # prepare a dict that has pointer to next layer give a layer name + # this is needed because neuron recycle reinitalizes both sides + # (incoming and outgoing weights) of a neuron and needs a point to the + # outgoing weights. + self.next_layers = {} + for current_layer, next_layer in zip(all_layers_names[:-1], + all_layers_names[1:]): + self.next_layers[current_layer] = next_layer + + # we don't recycle the neurons in the output layer. + self.reset_layers = self.reset_layers[:-1] + + # if network is resnet, recycle only the incoming/outgoing of the first conv + # layer in each block and final dense layer + if network == 'resnet': + self.reset_layers = [] + for layer in self.all_layers_names: + if 'Conv_1' in layer or 'Conv_3' in layer or 'Dense' in layer: + self.reset_layers.append(layer) + + def intersected_dead_neurons_with_last_reset(self, intermediates, + update_step): + if self.is_reset(update_step): + log_dict = self.log_intersected_dead_neurons(intermediates) + return log_dict + else: + return None + + def is_reset(self, update_step): + within_reset_interval = ( + update_step >= self.reset_start_step and + update_step < self.reset_end_step) + return self.is_update_iter(update_step) and within_reset_interval + + def is_intermediated_required(self, update_step): + is_logging = self.is_logging_step(update_step) + is_update_iter = self.is_update_iter(update_step) + return is_logging or is_update_iter + + def update_reset_layers(self, reset_start_layer_idx): + self.reset_layers = self.all_layers_names[reset_start_layer_idx:] + self.reset_layers = self.reset_layers[:-1] + + def update_weights(self, intermediates, params, key, opt_state): + new_param, opt_state = self.recycle_dead_neurons( + intermediates, params, key, opt_state + ) + return new_param, opt_state + + def recycle_dead_neurons(self, intermedieates, params, key, opt_state): + """Recycle dead neurons by reinitalizie incoming and outgoing connections. + + Incoming connections are randomly initalized and outgoing connections + are zero initalized. + A featuremap is considered dead when its score is below or equal + dead neuron threshold. + Args: + intermedieates: pytree contains the activations over a batch. + params: current weights of the model. + key: used to generate random keys. + opt_state: state of optimizer. + + Returns: + new model params after recycling dead neurons. + opt_state: new state for the optimizer + + Raises: raise error if init_method_outgoing is not one of the following + (random, zero). + """ + activations_score_dict = flax.traverse_util.flatten_dict( + intermedieates, sep='/' + ) + param_dict = flax.traverse_util.flatten_dict(params, sep='/') + + # create incoming and outgoing masks and reset bias of dead neurons. + ( + incoming_mask_dict, + outgoing_mask_dict, + incoming_random_keys_dict, + outgoing_random_keys_dict, + param_dict, + ) = self.create_masks(param_dict, activations_score_dict, key) + + params = flax.core.freeze( + flax.traverse_util.unflatten_dict(param_dict, sep='/')) + incoming_random_keys = flax.core.freeze( + flax.traverse_util.unflatten_dict(incoming_random_keys_dict, sep='/')) + if self.init_method_outgoing == 'random': + outgoing_random_keys = flax.core.freeze( + flax.traverse_util.unflatten_dict(outgoing_random_keys_dict, sep='/')) + # reset incoming weights + incoming_mask = flax.core.freeze( + flax.traverse_util.unflatten_dict(incoming_mask_dict, sep='/')) + reinit_fn = functools.partial( + weight_reinit_random, + weight_scaling=self.weight_scaling, + scale=self.incoming_scale, + weights_type='incoming') + weight_random_reset_fn = jax.jit(functools.partial(jax.tree_map, reinit_fn)) + params = weight_random_reset_fn(params, incoming_mask, incoming_random_keys) + + # reset outgoing weights + outgoing_mask = flax.core.freeze( + flax.traverse_util.unflatten_dict(outgoing_mask_dict, sep='/')) + if self.init_method_outgoing == 'random': + reinit_fn = functools.partial( + weight_reinit_random, + weight_scaling=self.weight_scaling, + scale=self.outgoing_scale, + weights_type='outgoing') + weight_random_reset_fn = jax.jit( + functools.partial(jax.tree_map, reinit_fn)) + params = weight_random_reset_fn(params, outgoing_mask, + outgoing_random_keys) + elif self.init_method_outgoing == 'zero': + weight_zero_reset_fn = jax.jit( + functools.partial(jax.tree_map, weight_reinit_zero)) + params = weight_zero_reset_fn(params, outgoing_mask) + else: + raise ValueError(f'Invalid init method: {self.init_method_outgoing}') + # reset mu, nu of adam optimizer for recycled weights. + reset_momentum_fn = jax.jit(functools.partial(jax.tree_map, reset_momentum)) + new_mu = reset_momentum_fn(opt_state[0][1], incoming_mask) + new_mu = reset_momentum_fn(new_mu, outgoing_mask) + new_nu = reset_momentum_fn(opt_state[0][2], incoming_mask) + new_nu = reset_momentum_fn(new_nu, outgoing_mask) + opt_state_list = list(opt_state) + opt_state_list[0] = optax.ScaleByAdamState( + opt_state[0].count, mu=new_mu, nu=new_nu) + opt_state = tuple(opt_state_list) + return params, opt_state + + def _score2mask(self, activation, param, next_param, key): + del key, param, next_param + score = self.estimate_neuron_score(activation) + return score <= self.dead_neurons_threshold + + def create_masks(self, param_dict, activations_dict, key): + """create the masks for recycled weights based on neurons scores. + + Args: + param_dict: dict of model params. + activations_dict: dict of the neuron score of each layer. + key: used seed for random weights. + + Returns: + incoming_mask_dict + outgoing_mask_dict + ingoing_random_keys_dict + outgoing_random_keys_dict + param_dict + """ + incoming_mask_dict = { + k: jnp.zeros_like(p) if p.ndim != 1 else None + for k, p in param_dict.items() + } + outgoing_mask_dict = { + k: jnp.zeros_like(p) if p.ndim != 1 else None + for k, p in param_dict.items() + } + ingoing_random_keys_dict = {k: None for k in param_dict} + outgoing_random_keys_dict = { + k: None for k in param_dict + } if self.init_method_outgoing == 'random' else {} + + # prepare mask of incoming and outgoing recycled connections + for k in self.reset_layers: + param_key = 'params/' + k + '/kernel' + param = param_dict[param_key] + # This won't work for DRQ, since returned keys can be a list. + # We don't support that at the moment. + next_key = self.next_layers[k] + if isinstance(next_key, list): + next_key = next_key[0] + next_param = param_dict['params/' + next_key + '/kernel'] + activation = activations_dict[k + '_act/__call__'][0] + # TODO(evcu) Maybe use per_layer random keys here. + neuron_mask = self._score2mask(activation, param, next_param, key) + + # the for loop handles the case where a layer has multiple next layers + # like the case in DrQ where the output layer has multihead. + next_keys = ( + self.next_layers[k] + if isinstance(self.next_layers[k], list) else [self.next_layers[k]]) + for next_k in next_keys: + next_param_key = 'params/' + next_k + '/kernel' + next_param = param_dict[next_param_key] + incoming_mask, outgoing_mask = self.create_mask_helper( + neuron_mask, param, next_param) + incoming_mask_dict[param_key] = incoming_mask + outgoing_mask_dict[next_param_key] = outgoing_mask + key, subkey = random.split(key) + ingoing_random_keys_dict[param_key] = subkey + if self.init_method_outgoing == 'random': + key, subkey = random.split(key) + outgoing_random_keys_dict[next_param_key] = subkey + + # reset bias + bias_key = 'params/' + k + '/bias' + new_bias = jnp.zeros_like(param_dict[bias_key]) + param_dict[bias_key] = jnp.where(neuron_mask, new_bias, + param_dict[bias_key]) + + return (incoming_mask_dict, outgoing_mask_dict, ingoing_random_keys_dict, + outgoing_random_keys_dict, param_dict) + + def create_mask_helper(self, neuron_mask, current_param, next_param): + """generate incoming and outgoing weight mask given dead neurons mask. + + Args: + neuron_mask: mask of size equals the width of a layer. + current_param: incoming weights of a layer. + next_param: outgoing weights of a layer. + + Returns: + incoming_mask + outgoing_mask + """ + + def mask_creator(expansion_axis, expansion_axes, param, neuron_mask): + """create a mask of weight matrix given 1D vector of neurons mask. + + Args: + expansion_axis: List contains 1 axis. The dimension to expand the mask + for dense layers (weight shape 2D). + expansion_axes: List conrtains 3 axes. The dimensions to expand the + score for convolutional layers (weight shape 4D). + param: weight. + neuron_mask: 1D mask that represents dead neurons(features). + + Returns: + mask: mask of weight. + """ + if param.ndim == 2: + axes = expansion_axis + # flatten layer + # The size of neuron_mask is the same as the width of last conv layer. + # This conv layer will be flatten and connected to dense layer. + # we repeat each value of a feature map to cover the spatial dimension. + if axes[0] == 1 and (param.shape[0] > neuron_mask.shape[0]): + num_repeatition = int(param.shape[0] / neuron_mask.shape[0]) + neuron_mask = jnp.repeat(neuron_mask, num_repeatition, axis=0) + elif param.ndim == 4: + axes = expansion_axes + mask = jnp.expand_dims(neuron_mask, axis=tuple(axes)) + for i in range(len(axes)): + mask = jnp.repeat(mask, param.shape[axes[i]], axis=axes[i]) + return mask + + incoming_mask = mask_creator([0], [0, 1, 2], current_param, neuron_mask) + outgoing_mask = mask_creator([1], [0, 1, 3], next_param, neuron_mask) + return incoming_mask, outgoing_mask + + +class NeuronRecyclerScheduled(NeuronRecycler): + """Fixed scheduled version of the NeuronRecycler.""" + + def __init__( + self, + *args, + score_type='redo', + recycle_rate=0.3, + **kwargs, + ): + super(NeuronRecyclerScheduled, self).__init__(*args, **kwargs) + self.score_type = score_type + self.recycle_rate = recycle_rate + + def _score2mask(self, activation, param, next_param, key): + is_cbp = self.score_type == 'cbp' + score = self.estimate_neuron_score(activation, is_cbp=is_cbp) + if self.score_type == 'redo': + pass + elif self.score_type == 'random': + new_key = random.fold_in(key, self._last_update_step) + score = random.shuffle(new_key, score) + elif self.score_type == 'redo_inverted': + score = -score + # Metric used in Continual Backprop pape. + elif self.score_type == 'cbp': + next_axes = list(range(param.ndim)) + del next_axes[-2] + current_axes = list(range(param.ndim)) + del current_axes[-1] + if next_param.ndim == 2 and param.ndim == 4: + new_shape = activation.shape[1:] + (-1,) + next_param = jnp.reshape(next_param, new_shape) + score *= jnp.sum(jnp.abs(next_param), axis=next_axes) / jnp.sum( + jnp.abs(param), axis=current_axes + ) + multiplier = max(0, self._last_update_step / self.reset_end_step) + ones_fraction = float(jnp.cos(jnp.pi * 0.5 * multiplier)) + ones_fraction *= self.recycle_rate + return leastk_mask(score, ones_fraction) + + + +if __name__ == "__main__": + import gym + import numpy as np + import procgen + import torch + import torch.nn as nn + import torch.optim as optim + from torch.distributions.categorical import Categorical + + def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + # taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py + class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + + def forward(self, x): + inputs = x + x = nn.functional.relu(x) + x = self.conv0(x) + x = nn.functional.relu(x) + x = self.conv1(x) + return x + inputs + + class ConvSequence(nn.Module): + def __init__(self, input_shape, out_channels): + super().__init__() + self._input_shape = input_shape + self._out_channels = out_channels + self.conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1) + self.res_block0 = ResidualBlock(self._out_channels) + self.res_block1 = ResidualBlock(self._out_channels) + + def forward(self, x): + x = self.conv(x) + x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1) + x = self.res_block0(x) + x = self.res_block1(x) + assert x.shape[1:] == self.get_output_shape() + return x + + def get_output_shape(self): + _c, h, w = self._input_shape + return (self._out_channels, (h + 1) // 2, (w + 1) // 2) + + class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + h, w, c = envs.single_observation_space.shape + shape = (c, h, w) + conv_seqs = [] + for out_channels in [16, 32, 32]: + conv_seq = ConvSequence(shape, out_channels) + shape = conv_seq.get_output_shape() + conv_seqs.append(conv_seq) + conv_seqs += [ + nn.Flatten(), + nn.ReLU(), + nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=256), + nn.ReLU(), + ] + self.network = nn.Sequential(*conv_seqs) + self.actor = layer_init(nn.Linear(256, envs.single_action_space.n), std=0.01) + self.critic = layer_init(nn.Linear(256, 1), std=1) + + def get_value(self, x): + return self.critic(self.network(x.permute((0, 3, 1, 2)) / 255.0)) # "bhwc" -> "bchw" + + def get_action_and_value(self, x, action=None, full_log_probs=False): + hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0) # "bhwc" -> "bchw" + value = self.critic(hidden) + logits = self.actor(hidden) + dist = Categorical(logits=logits) + if action is None: + action = dist.sample() + + action_log_probs = torch.squeeze(dist.log_prob(action)) + dist_entropy = dist.entropy() + + if full_log_probs: + log_probs = torch.log(dist.probs) + return action, action_log_probs, dist_entropy, value, log_probs + + return action, action_log_probs, dist_entropy, value + + def make_env(env_id, seed): + def thunk(): + env = gym.make(f"procgen-{env_id}-v0", rand_seed=seed, distribution_mode="easy") + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + return thunk + + envs = gym.vector.SyncVectorEnv( + [ + make_env("bigfish", 1 + i) + for i in range(1) + ] + ) + agent = Agent(envs).to("cuda") + optimizer = optim.Adam(agent.parameters(), lr=5e-4, eps=1e-5) + + online_params = [] + for name, module in agent.named_children(): + if not name.startswith('params'): + online_params.append(name) + weight_recycler = NeuronRecycler(online_params) + rng = jax.random.PRNGKey(1) + + for update_step in range(10000): + # Neuron/layer recycling starts if reset_mode is not None. + # Otherwise, we log dead neurons over training for standard agent. + is_intermediated = weight_recycler.is_intermediated_required(update_step) + + # get intermediate activation per layer to calculate neuron score. + def get_intermediates(online_params): + batch = self._sample_batch_for_statistics() + + def apply_data(x): + filter_rep = lambda l, _: l.name is not None and 'act' in l.name + return self.network_def.apply( + online_params, + x, + capture_intermediates=filter_rep, + mutable=['intermediates']) + + _, state = jax.vmap(apply_data)(batch) + return state['intermediates'] + + intermediates = get_intermediates(agent.parameters()) if is_intermediated else None + log_dict_neurons = (weight_recycler.maybe_log_deadneurons(update_step, intermediates)) + + # Neuron/layer recyling. + rng, key = jax.random.split(rng) + online_params, opt_state = weight_recycler.maybe_update_weights( + update_step, intermediates, online_params, key, optimizer.state_dict() + ) + optimizer.load_state_dict(opt_state) + with torch.no_grad(): + for agent_param, new_param in zip(agent.parameters().values(), online_params.values()): + agent_param.data = new_param diff --git a/examples/experimental/rllib_cartpole.py b/examples/experimental/rllib_cartpole.py new file mode 100644 index 00000000..8877a279 --- /dev/null +++ b/examples/experimental/rllib_cartpole.py @@ -0,0 +1,40 @@ +import gymnasium as gym +from gymnasium.spaces import Box +from ray import tune +from ray.tune.registry import register_env +from syllabus.core import RaySyncWrapper, make_ray_curriculum +from syllabus.curricula import SimpleBoxCurriculum +from syllabus.task_space import TaskSpace + +from .task_wrappers import CartPoleTaskWrapper + +# Define a task space +if __name__ == "__main__": + task_space = TaskSpace(Box(-0.3, 0.3, shape=(2,))) + + def env_creator(config): + env = gym.make("CartPole-v1") + # Wrap the environment to change tasks on reset() + env = CartPoleTaskWrapper(env) + # Add environment sync wrapper + env = RaySyncWrapper( + env, task_space=task_space, update_on_step=False + ) + return env + + register_env("task_cartpole", env_creator) + + # Create the curriculum + curriculum = SimpleBoxCurriculum(task_space) + # Add the curriculum sync wrapper + curriculum = make_ray_curriculum(curriculum) + + config = { + "env": "task_cartpole", + "num_gpus": 1, + "num_workers": 8, + "framework": "torch", + } + + tuner = tune.Tuner("APEX", param_space=config) + results = tuner.fit() diff --git a/examples/experimental/sb3_procgen_plr.py b/examples/experimental/sb3_procgen_plr.py new file mode 100644 index 00000000..f9cde9f9 --- /dev/null +++ b/examples/experimental/sb3_procgen_plr.py @@ -0,0 +1,116 @@ +from typing import Callable + +import gym +import procgen # noqa: F401 +import wandb +from stable_baselines3 import PPO +from stable_baselines3.common.callbacks import BaseCallback, CallbackList +from stable_baselines3.common.vec_env import (DummyVecEnv, VecMonitor, + VecNormalize) +from syllabus.core import (MultiProcessingSyncWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples.task_wrappers import ProcgenTaskWrapper +from wandb.integration.sb3 import WandbCallback + + +def make_env(task_queue, update_queue, start_level=0, num_levels=1): + def thunk(): + env = gym.make("procgen-bigfish-v0", distribution_mode="easy", start_level=start_level, num_levels=num_levels) + env = ProcgenTaskWrapper(env) + env = MultiProcessingSyncWrapper( + env, + task_queue, + update_queue, + update_on_step=False, + task_space=env.task_space, + ) + return env + return thunk + + +def wrap_vecenv(vecenv): + vecenv.is_vector_env = True + vecenv = VecMonitor(venv=vecenv, filename=None) + vecenv = VecNormalize(venv=vecenv, norm_obs=False, norm_reward=True) + return vecenv + + +class CustomCallback(BaseCallback): + """ + A custom callback that derives from ``BaseCallback``. + + :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages + """ + def __init__(self, curriculum, verbose=0): + super().__init__(verbose) + self.curriculum = curriculum + + def _on_step(self) -> bool: + tasks = self.training_env.venv.venv.venv.get_attr("task") + + update = { + "update_type": "on_demand", + "metrics": { + "value": self.locals["values"], + "next_value": self.locals["values"], + "rew": self.locals["rewards"], + "dones": self.locals["dones"], + "tasks": tasks, + }, + } + self.curriculum.update_curriculum(update) + return True + + +def linear_schedule(initial_value: float) -> Callable[[float], float]: + def func(progress_remaining: float) -> float: + return progress_remaining * initial_value + return func + + +run = wandb.init( + project="sb3", + entity="ryansullivan", + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + save_code=True, # optional +) + + +sample_env = gym.make("procgen-bigfish-v0") +sample_env = ProcgenTaskWrapper(sample_env) +curriculum = CentralizedPrioritizedLevelReplay(sample_env.task_space, num_processes=64, num_steps=256) +curriculum, task_queue, update_queue = make_multiprocessing_curriculum(curriculum) +venv = DummyVecEnv( + [ + make_env(task_queue, update_queue, num_levels=0) + for i in range(64) + ] +) +venv = wrap_vecenv(venv) + +model = PPO( + "CnnPolicy", + venv, + verbose=1, + n_steps=256, + learning_rate=linear_schedule(0.0005), + gamma=0.999, + gae_lambda=0.95, + n_epochs=3, + clip_range_vf=0.2, + ent_coef=0.01, + batch_size=256 * 64, + tensorboard_log="runs/testing" +) + +wandb_callback = WandbCallback( + model_save_path=f"models/{run.id}", + verbose=2, +) +plr_callback = CustomCallback(curriculum) +callback = CallbackList([wandb_callback, plr_callback]) +model.learn( + 25000000, + callback=callback, +) diff --git a/examples/experimental/torchbeast_nethack.py b/examples/experimental/torchbeast_nethack.py new file mode 100644 index 00000000..21767b1f --- /dev/null +++ b/examples/experimental/torchbeast_nethack.py @@ -0,0 +1,1280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is an example self-contained agent running NLE based on MonoBeast. + +import argparse +import logging +import math +import os +import pprint +import threading +import time +import timeit +import traceback + +import numpy as np +import wandb +from syllabus.core import (MultiProcessingSyncWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import LearningProgressCurriculum +from syllabus.examples import NethackTaskWrapper + +# Necessary for multithreading. +os.environ["OMP_NUM_THREADS"] = "1" + +try: + import torch + from torch import multiprocessing as mp + from torch import nn + from torch.nn import functional as F +except ImportError: + logging.exception( + "PyTorch not found. Please install the agent dependencies with " + '`pip install "nle[agent]"`' + ) + +import gym # noqa: E402 +import nle # noqa: F401, E402 +from nle import nethack # noqa: E402 +from nle.agent import vtrace # noqa: E402 + + +def parse_args(): + # yapf: disable + parser = argparse.ArgumentParser(description="PyTorch Scalable Agent") + + parser.add_argument("--env", type=str, default="NetHackScore-v0", + help="Gym environment.") + parser.add_argument("--mode", default="train", + choices=["train", "test", "test_render"], + help="Training or test mode.") + parser.add_argument("--profile", action="store_true", + help="Profile main process.") + parser.add_argument("--profile_worker", action="store_true", + help="Profile worker process.") + + # Training settings. + parser.add_argument("--disable_checkpoint", action="store_true", + help="Disable saving checkpoint.") + parser.add_argument("--savedir", default="~/torchbeast/", + help="Root dir where experiment data will be saved.") + parser.add_argument("--num_actors", default=8, type=int, metavar="N", + help="Number of actors (default: 4).") + parser.add_argument("--total_steps", default=100000, type=int, metavar="T", + help="Total environment steps to train for.") + parser.add_argument("--batch_size", default=16, type=int, metavar="B", + help="Learner batch size.") + parser.add_argument("--unroll_length", default=80, type=int, metavar="T", + help="The unroll length (time dimension).") + parser.add_argument("--num_buffers", default=None, type=int, + metavar="N", help="Number of shared-memory buffers.") + parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int, + metavar="N", help="Number learner threads.") + parser.add_argument("--disable_cuda", action="store_true", + help="Disable CUDA.") + parser.add_argument("--use_lstm", action="store_true", + help="Use LSTM in agent model.") + parser.add_argument("--save_ttyrec_every", default=1000, type=int, + metavar="N", help="Save ttyrec every N episodes.") + parser.add_argument("--save_video", action="store_true", + help="Save and log video during training.") + + # Curriculum Settings + parser.add_argument("--curriculum", action="store_true", + help="Use Syllabus curricula.") + + # Testing settings + parser.add_argument("--reward_frames", action="store_true", + help="Only print reward frames and show inventory.") + parser.add_argument("--item_frames", action="store_true", + help="Only print frames where the agent picks up items and show inventory.") + parser.add_argument("--message", action="store_true", + help="Set to true without the above options to display only the messages.") + parser.add_argument("--custompath", type=str, + help="Set a custom path to draw a tar test file from.") + + # Weights and Biases settings + parser.add_argument("--exp_name", type=str, default="nle_baseline", + help="Set name for wandb experiment.") + parser.add_argument("--wandb_id", default=1, type=int, + help="Set id for wandb experiment.") + + # Loss settings. + parser.add_argument("--entropy_cost", default=0.0006, + type=float, help="Entropy cost/multiplier.") + parser.add_argument("--baseline_cost", default=0.5, + type=float, help="Baseline cost/multiplier.") + parser.add_argument("--discounting", default=0.99, + type=float, help="Discounting factor.") + parser.add_argument("--reward_clipping", default="abs_one", + choices=["abs_one", "none"], + help="Reward clipping.") + + # Optimizer settings. + parser.add_argument("--learning_rate", default=0.00048, + type=float, metavar="LR", help="Learning rate.") + parser.add_argument("--alpha", default=0.99, type=float, + help="RMSProp smoothing constant.") + parser.add_argument("--momentum", default=0, type=float, + help="RMSProp momentum.") + parser.add_argument("--epsilon", default=0.01, type=float, + help="RMSProp epsilon.") + parser.add_argument("--grad_norm_clipping", default=40.0, type=float, + help="Global gradient norm clip.") + # yapf: enable + args = parser.parse_args() + args.exp_name = f"nethack__{args.exp_name}__{int(time.time())}" + return args + + +logging.basicConfig( + format=("[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"), + level=logging.INFO, + #filename="monobeast.log" +) + + +def nested_map(f, n): + if isinstance(n, tuple) or isinstance(n, list): + return n.__class__(nested_map(f, sn) for sn in n) + elif isinstance(n, dict): + return {k: nested_map(f, v) for k, v in n.items()} + else: + return f(n) + + +def compute_baseline_loss(advantages): + return 0.5 * torch.sum(advantages**2) + + +def compute_entropy_loss(logits): + """Return the entropy loss, i.e., the negative entropy of the policy.""" + policy = F.softmax(logits, dim=-1) + log_policy = F.log_softmax(logits, dim=-1) + return torch.sum(policy * log_policy) + + +def compute_policy_gradient_loss(logits, actions, advantages): + cross_entropy = F.nll_loss( + F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), + target=torch.flatten(actions, 0, 1), + reduction="none", + ) + cross_entropy = cross_entropy.view_as(advantages) + return torch.sum(cross_entropy * advantages.detach()) + + +def create_env( + name, + *args, + chars=False, + task_queue: mp.SimpleQueue = None, + complete_queue: mp.SimpleQueue = None, + step_queue: mp.SimpleQueue = None, + observation_keys=None, + **kwargs +): + if flags.curriculum: + observation_keys = ("glyphs", "blstats", "message") + #observation_keys = ("glyphs", "blstats", "inv_glyphs", "inv_strs", "inv_letters", "inv_oclasses", + # "message", "tty_chars", "tty_colors", "tty_cursor", "internal") + if chars: + observation_keys += ("chars",) + + env = gym.make(name, *args, observation_keys=observation_keys, penalty_step=0.0, **kwargs) + env = NethackTaskWrapper(env) + env = MultiProcessingSyncWrapper(env, + task_queue, + complete_queue, + step_queue=step_queue, + update_on_step=True, + default_task=0, + task_space=env.task_space) + else: + env = gym.make(name, *args, observation_keys=observation_keys, **kwargs) + env = ResettingEnvironment(env) + return env + + +def act( + flags, + actor_index: int, + free_queue: mp.SimpleQueue, + full_queue: mp.SimpleQueue, + model: torch.nn.Module, + buffers, + initial_agent_state_buffers, + task_queue: mp.SimpleQueue = None, + complete_queue: mp.SimpleQueue = None, + step_queue: mp.SimpleQueue = None, + + video_dir=None +): + try: + if flags.profile_worker and actor_index == 0: + import cProfile + from pstats import Stats + pr = cProfile.Profile() + pr.enable() + + logging.info("Actor %i started.", actor_index) + + gym_env = create_env( + flags.env, + savedir=flags.rundir, + save_ttyrec_every=flags.save_ttyrec_every, + observation_keys=("glyphs", "blstats"), + task_queue=task_queue, + complete_queue=complete_queue, + step_queue=step_queue, + ) + + env = gym_env + env_output = env.initial() + agent_state = model.initial_state(batch_size=1) + agent_output, _ = model(env_output, agent_state) + env_output = env.step(agent_output["action"]) + + while True: + index = free_queue.get() + if index is None: + break + + # Write old rollout end. + for key in env_output: + buffers[key][index][0, ...] = env_output[key] + for key in agent_output: + buffers[key][index][0, ...] = agent_output[key] + for i, tensor in enumerate(agent_state): + initial_agent_state_buffers[index][i][...] = tensor + + # Do new rollout. + for t in range(flags.unroll_length): + with torch.no_grad(): + agent_output, agent_state = model(env_output, agent_state) + + env_output = env.step(agent_output["action"]) + if flags.curriculum and "info" in env_output: + info = env_output["info"] + + for key in env_output: + buffers[key][index][t + 1, ...] = env_output[key] + for key in agent_output: + buffers[key][index][t + 1, ...] = agent_output[key] + + full_queue.put(index) + + if flags.profile_worker and actor_index == 0: + pr.disable() + stats = Stats(pr) + stats.sort_stats('cumtime').print_stats(200) + + except KeyboardInterrupt: + pass # Return silently. + except Exception: + logging.error("Exception in worker process %i", actor_index) + traceback.print_exc() + print() + raise + + +def get_batch( + flags, + free_queue: mp.SimpleQueue, + full_queue: mp.SimpleQueue, + buffers, + initial_agent_state_buffers, + lock=threading.Lock(), +): + with lock: + indices = [full_queue.get() for _ in range(flags.batch_size)] + batch = { + key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers + } + + initial_agent_state = ( + torch.cat(ts, dim=1) + for ts in zip(*[initial_agent_state_buffers[m] for m in indices]) + ) + for m in indices: + free_queue.put(m) + batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()} + initial_agent_state = tuple( + t.to(device=flags.device, non_blocking=True) for t in initial_agent_state + ) + return batch, initial_agent_state + + +def learn( + flags, + actor_model, + model, + batch, + initial_agent_state, + optimizer, + scheduler, + lock=threading.Lock(), # noqa: B008 +): + """Performs a learning (optimization) step.""" + with lock: + learner_outputs, unused_state = model(batch, initial_agent_state) + + # Take final value function slice for bootstrapping. + bootstrap_value = learner_outputs["baseline"][-1] + + # Move from obs[t] -> action[t] to action[t] -> obs[t]. + batch = {key: tensor[1:] for key, tensor in batch.items()} + learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()} + + rewards = batch["reward"] + if flags.reward_clipping == "abs_one": + clipped_rewards = torch.clamp(rewards, -1, 1) + elif flags.reward_clipping == "none": + clipped_rewards = rewards + + discounts = (~batch["done"]).float() * flags.discounting + + vtrace_returns = vtrace.from_logits( + behavior_policy_logits=batch["policy_logits"], + target_policy_logits=learner_outputs["policy_logits"], + actions=batch["action"], + discounts=discounts, + rewards=clipped_rewards, + values=learner_outputs["baseline"], + bootstrap_value=bootstrap_value, + ) + + pg_loss = compute_policy_gradient_loss( + learner_outputs["policy_logits"], + batch["action"], + vtrace_returns.pg_advantages, + ) + baseline_loss = flags.baseline_cost * compute_baseline_loss( + vtrace_returns.vs - learner_outputs["baseline"] + ) + entropy_loss = flags.entropy_cost * compute_entropy_loss( + learner_outputs["policy_logits"] + ) + + total_loss = pg_loss + baseline_loss + entropy_loss + + episode_returns = batch["episode_return"][batch["done"]] + episode_steps = batch["episode_step"][batch["done"]] + + stats = { + "episode_returns": tuple(episode_returns.cpu().numpy()), + "episode_lengths": tuple(episode_steps.cpu().numpy()), + "mean_episode_length": torch.mean(episode_steps).item(), + "mean_episode_return": torch.mean(episode_returns).item(), + "total_loss": total_loss.item(), + "pg_loss": pg_loss.item(), + "baseline_loss": baseline_loss.item(), + "entropy_loss": entropy_loss.item(), + } + + optimizer.zero_grad() + total_loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping) + optimizer.step() + scheduler.step() + + actor_model.load_state_dict(model.state_dict()) + return stats + + +def create_buffers(flags, observation_space, num_actions, num_overlapping_steps=1): + size = (flags.unroll_length + num_overlapping_steps,) + + # Get specimens to infer shapes and dtypes. + samples = {k: torch.from_numpy(v) for k, v in observation_space.sample().items()} + + specs = { + key: dict(size=size + sample.shape, dtype=sample.dtype) + for key, sample in samples.items() + } + specs.update( + reward=dict(size=size, dtype=torch.float32), + done=dict(size=size, dtype=torch.bool), + episode_return=dict(size=size, dtype=torch.float32), + episode_step=dict(size=size, dtype=torch.float32), + policy_logits=dict(size=size + (num_actions,), dtype=torch.float32), + baseline=dict(size=size, dtype=torch.float32), + last_action=dict(size=size, dtype=torch.int64), + action=dict(size=size, dtype=torch.int64), + ) + buffers = {key: [] for key in specs} + for _ in range(flags.num_buffers): + for key in buffers: + buffers[key].append(torch.empty(**specs[key]).share_memory_()) + return buffers + + +def _format_observations(observation): + observations = {} + for key in list(observation.keys()): + entry = observation[key] + if isinstance(entry, np.ndarray): + entry = torch.from_numpy(entry) + entry = entry.view((1, 1) + entry.shape) # (...) -> (T,B,...). + observations[key] = entry + return observations + + +class ResettingEnvironment: + """Turns a Gym environment into something that can be step()ed indefinitely.""" + + def __init__(self, gym_env): + self.gym_env = gym_env + self.episode_return = None + self._copy_gym_properties() + + def _copy_gym_properties(self): + self.action_space = self.gym_env.action_space + self.observation_space = self.gym_env.observation_space + self.reward_range = self.gym_env.reward_range + self.metadata = self.gym_env.metadata + if flags.curriculum: + self.task_space = self.gym_env.task_space + + def initial(self): + initial_reward = torch.zeros(1, 1) + # This supports only single-tensor actions ATM. + initial_last_action = torch.zeros(1, 1, dtype=torch.int64) + self.episode_return = torch.zeros(1, 1) + self.episode_step = torch.zeros(1, 1, dtype=torch.float32) + initial_done = torch.ones(1, 1, dtype=torch.uint8) + + result = _format_observations(self.gym_env.reset()) + result.update( + reward=initial_reward, + done=initial_done, + episode_return=self.episode_return, + episode_step=self.episode_step, + last_action=initial_last_action, + ) + return result + + def step(self, action): + observation, reward, done, info = self.gym_env.step(action.item()) + self.episode_step += 1 + self.episode_return += reward + episode_step = self.episode_step + episode_return = self.episode_return + + if done: + observation = self.gym_env.reset() + self.episode_return = torch.zeros(1, 1) + self.episode_step = torch.zeros(1, 1, dtype=torch.float32) + + result = _format_observations(observation) + + reward = torch.tensor(reward).view(1, 1) + done = torch.tensor(done).view(1, 1) + if flags.curriculum: + result.update( + reward=reward, + done=done, + episode_return=episode_return, + episode_step=episode_step, + last_action=action, + ) + else: + result.update( + reward=reward, + done=done, + episode_return=episode_return, + episode_step=episode_step, + last_action=action, + ) + + return result + + def close(self): + self.gym_env.close() + + +def parse_logpaths(flags): + flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) + + if flags.exp_name: + rundir = os.path.join(flags.savedir, f"torchbeast-{flags.exp_name}-{flags.wandb_id}") + else: + rundir = os.path.join(flags.savedir, "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")) + + checkpointpath = os.path.join(rundir, "model.tar") + + # TODO: Check if run name + id exists, and resume + resume_checkpoint = None + if os.path.exists(rundir): + resume_checkpoint = torch.load(checkpointpath) + # TODO: Make sure this doesn't overwrite anything important, + # or that we want to be able to change + # flags = Namespace(**resume_checkpoint["flags"]) + else: + os.makedirs(rundir) + + logfile = open(os.path.join(rundir, "logs.tsv"), "a", buffering=1) + logging.info("Logging results to %s", rundir) + + symlink = os.path.join(flags.savedir, "latest") + try: + if os.path.islink(symlink): + os.remove(symlink) + if not os.path.exists(symlink): + os.symlink(rundir, symlink) + logging.info("Symlinked log directory: %s", symlink) + except OSError: + raise + + flags.rundir = rundir + return flags, rundir, checkpointpath, resume_checkpoint, logfile + + +def train(flags, wandb_run=None): # pylint: disable=too-many-branches, too-many-statements + # Set all filepaths using provided arguments + flags, rundir, checkpointpath, resume_checkpoint, logfile = parse_logpaths(flags) + + if flags.num_buffers is None: # Set sensible default for num_buffers. + flags.num_buffers = max(2 * flags.num_actors, flags.batch_size) + if flags.num_actors >= flags.num_buffers: + raise ValueError("num_buffers should be larger than num_actors") + if flags.num_buffers < flags.batch_size: + raise ValueError("num_buffers should be larger than batch_size") + + T = flags.unroll_length + B = flags.batch_size + + flags.device = None + if not flags.disable_cuda and torch.cuda.is_available(): + logging.info("Using CUDA.") + flags.device = torch.device("cuda") + else: + logging.info("Not using CUDA.") + flags.device = torch.device("cpu") + + sample_env = create_env(flags.env, observation_keys=("glyphs", "blstats")) + observation_space = sample_env.observation_space + action_space = sample_env.action_space + if flags.curriculum: + task_space = sample_env.task_space + + model = Net(observation_space, action_space.n, flags.use_lstm, goal=flags.curriculum) + buffers = create_buffers(flags, observation_space, model.num_actions) + + model.share_memory() + + def lr_lambda(epoch): + return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps + + # Resume training + # if resume_checkpoint: + # # TODO: Fix resume code + # #curriculum = LearningProgressCurriculum(tasks=resume_checkpoint["curriculum"]) + # if flags.curriculum: + # curriculum = LearningProgressCurriculum(task_space=task_space) + # model.load_state_dict(resume_checkpoint["model"]) + # learner_model = Net(observation_space, action_space.n, flags.use_lstm, goal=flags.curriculum).to( + # device=flags.device + # ) + # learner_model.load_state_dict(model.state_dict()) + # optimizer = torch.optim.RMSprop( + # learner_model.parameters(), + # lr=flags.learning_rate, + # momentum=flags.momentum, + # eps=flags.epsilon, + # alpha=flags.alpha, + # ) + # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + # scheduler.load_state_dict(resume_checkpoint["scheduler"]) + # optimizer.load_state_dict(resume_checkpoint['optimizer']) + # step = resume_checkpoint["step"] + # + # else: + task_queue, complete_queue, step_queue = None, None, None + if flags.curriculum: + name_list = None + if isinstance(sample_env.task_space, gym.spaces.Discrete): + task_list = sample_env.gym_env.task_list + name_list = [task_list[idx].__name__ for idx in range(len(task_list))] + curriculum, task_queue, complete_queue, step_queue = make_multiprocessing_curriculum(LearningProgressCurriculum, + task_space, + random_start_tasks=0, + task_names=name_list) + + learner_model = Net(observation_space, action_space.n, flags.use_lstm, goal=flags.curriculum).to(device=flags.device) + learner_model.load_state_dict(model.state_dict()) + + optimizer = torch.optim.RMSprop( + learner_model.parameters(), + lr=flags.learning_rate, + momentum=flags.momentum, + eps=flags.epsilon, + alpha=flags.alpha, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + step = 0 + + # Add initial RNN state. + initial_agent_state_buffers = [] + for _ in range(flags.num_buffers): + state = model.initial_state(batch_size=1) + for t in state: + t.share_memory_() + initial_agent_state_buffers.append(state) + + del sample_env # End this before forking. + + actor_processes = [] + ctx = mp.get_context("fork") + free_queue = ctx.SimpleQueue() + full_queue = ctx.SimpleQueue() + + for i in range(flags.num_actors): + actor = ctx.Process( + target=act, + args=( + flags, i, + free_queue, full_queue, + model, buffers, + initial_agent_state_buffers, + task_queue, complete_queue, step_queue, + rundir, + ), + name="Actor-%i" % i, + ) + actor.start() + actor_processes.append(actor) + + if flags.exp_name: + wandb.config = { + "learning_rate": flags.learning_rate, + "epsilon": flags.epsilon, + "alpha": flags.alpha, + "momentum": flags.momentum, + } + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + stat_keys = ["total_loss", "mean_episode_return", "pg_loss", "baseline_loss", "entropy_loss"] + logfile.write("# Step\t%s\n" % "\t".join(stat_keys)) + + all_stats = [] + stats = {} + + def batch_and_learn(i, lock=threading.Lock()): + """Thread target for the learning process.""" + nonlocal step, stats, all_stats + while step < flags.total_steps: + batch, agent_state = get_batch(flags, free_queue, full_queue, buffers, initial_agent_state_buffers) + stats = learn(flags, model, learner_model, batch, agent_state, optimizer, scheduler) + + all_stats.append(stats) + with lock: + logfile.write("%i\t" % step) + logfile.write("\t".join(str(stats[k]) for k in stat_keys)) + logfile.write("\n") + step += T * B + + for m in range(flags.num_buffers): + free_queue.put(m) + + threads = [] + for i in range(flags.num_learner_threads): + thread = threading.Thread( + target=batch_and_learn, + name="batch-and-learn-%d" % i, + args=(i,), + daemon=True, # To support KeyboardInterrupt below. + ) + thread.start() + threads.append(thread) + + def checkpoint(): + if flags.disable_checkpoint: + return + logging.info("Saving checkpoint to %s", checkpointpath) + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "flags": vars(flags), + "step": step + }, + checkpointpath, + ) + + timer = timeit.default_timer + recent_sps = [1.0] * 5 + try: + if flags.curriculum: + curriculum.log_metrics(step) + last_checkpoint_time = timer() + while step < flags.total_steps: + start_step = step + start_time = timer() + time.sleep(10) + + if timer() - last_checkpoint_time > 10 * 60: # Save every 10 min. + checkpoint() + last_checkpoint_time = timer() + #wandb.gym.monitor() + + # # Combine stats + all_stats_dict = {} + if len(all_stats) > 0: + # Iterate through keys and combine based on data type + for key, value in all_stats[0].items(): + all_stats_dict[key] = 0 + stat_list = [] + if isinstance(value, (int, float, np.uint8, np.float32)): + # Average values + for stat_dict in all_stats: + stat_value = stat_dict.get(key) + if stat_value is not None and not np.isnan(stat_value) and not math.isnan(stat_value): + stat_list.append(stat_value) + all_stats_dict[key] = (sum(stat_list) / len(stat_list)) if len(stat_list) > 0 else 0 + elif isinstance(value, (list, tuple)): + # Combine lists + for stat_dict in all_stats: + stat_value = stat_dict.get(key) + if stat_value is not None and stat_value != () and stat_value != []: + stat_list += stat_value + all_stats_dict[key] = stat_list + all_stats = [] + + # Remove clutter + if "episode_returns" in all_stats_dict: + del all_stats_dict["episode_returns"] + if "episode_lengths" in all_stats_dict: + del all_stats_dict["episode_lengths"] + + # Log run data to weights and biases + if flags.exp_name: + wandb_stats = all_stats_dict + for key, value in wandb_stats.items(): + if isinstance(value, (int, float, np.uint8, np.float32)): + wandb_stats[key] = 0.0 if np.isnan(value) else value + if not flags.curriculum: + wandb_stats["mean_score_return"] = wandb_stats["mean_episode_return"] + wandb_stats["learning_rate"] = scheduler.get_last_lr()[0] + # task_table.add_data(step, str(curriculum.export_task_names())) + wandb.log(wandb_stats, step=step) + + sps = (step - start_step) / (timer() - start_time) + + if stats.get("episode_returns", None): + mean_return = ( + "Return per episode: %.1f. " % stats["mean_episode_return"] + ) + else: + mean_return = "" + total_loss = stats.get("total_loss", float("inf")) + log_str = "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s" + log_args = [step, sps, total_loss, mean_return, pprint.pformat(all_stats_dict)] + # log_str = "Steps %i @ %.1f SPS. Loss %f. %s" + # log_args = [step, sps, total_loss, mean_return] + # if flags.curriculum: + # # TODO: Fix this + # task_names = curriculum.export_task_names()[:] + # log_str += "\nTasks: %s\nL_prog: %s\nP_fast: %s" + # log_args.append(task_names) + # # Get learning progress metric + # lps = curriculum.metric_for_tasks(task_names, metric="lp") + # lps = [f"{lp:.4f}" for lp in lps] + # log_args.append(lps) + # # Get recent estimate of success rates + # pfasts = curriculum.metric_for_tasks(task_names, metric="p_fast") + # pfasts = [f"{pfast:.4f}" for pfast in pfasts] + # log_args.append(pfasts) + logging.info(log_str, *log_args) + if flags.curriculum: + curriculum.log_metrics(step) + + # Stop training if sps remains 0 for too long + if total_loss != float("inf"): + for i in reversed(range(1, len(recent_sps))): + recent_sps[i] = recent_sps[i-1] + recent_sps[0] = sps + if sum(recent_sps) == 0.0: + return + + except KeyboardInterrupt: + logging.warning("Quitting.") + return # Try joining actors then quit. + else: + for thread in threads: + thread.join() + logging.info("Learning finished after %d steps.", step) + finally: + for _ in range(flags.num_actors): + free_queue.put(None) + for actor in actor_processes: + actor.join(timeout=1) + + checkpoint() + logfile.close() + + +def test(flags, num_episodes=1): + flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) + print("savedir:" + str(flags.savedir)) + checkpointpath = flags.custompath if flags.custompath else os.path.join(flags.savedir, "latest", "model.tar") + + #gym_env = create_env(flags.env, save_ttyrecs=flags.save_ttyrecs) + observation_keys = ("glyphs", "blstats") + observation_keys = ("glyphs", "blstats", "inv_glyphs", "inv_strs", "inv_letters", "inv_oclasses", + "message", "tty_chars", "tty_colors", "tty_cursor") + gym_env = create_env(flags.env, observation_keys=observation_keys) + env = ResettingEnvironment(gym_env) + model = Net(gym_env.observation_space, gym_env.action_space.n, flags.use_lstm, goal=flags.curriculum) + model.eval() + checkpoint = torch.load(checkpointpath, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + + observation = env.initial() + returns = [] + + agent_state = model.initial_state(batch_size=1) + last_inv = [] + + while len(returns) < num_episodes: + if flags.mode == "test_render": + if flags.item_frames or flags.reward_frames: + #The following lines parse the inventory and store it as an array in all_items_txt + all_items_text = [] + all_items_encoded = observation["inv_strs"][0][0].numpy() + for x in range(55): + if(all_items_encoded[x][0] != 0): + all_items_text.append(''.join([str((chr(elem))) for elem in all_items_encoded[x][all_items_encoded[x] != 0]])) + if observation["reward"].item() > 0 and flags.reward_frames or all_items_text != last_inv and flags.item_frames: + env.gym_env.render() + print("Reward is %d" % observation["reward"].item()) + print(all_items_text) + last_inv = all_items_text + elif flags.message: + #Parse and print the message + message_encoded = observation["message"][0][0] + message = "" + for x in range(256): + message += chr(message_encoded[x].item()) + print(message) + else: + env.gym_env.render() + policy_outputs, agent_state = model(observation, agent_state) + observation = env.step(policy_outputs["action"]) + if observation["done"].item(): + last_inv = [] + returns.append(observation["episode_return"].item()) + logging.info( + "Episode ended after %d steps. Return: %.1f", + observation["episode_step"].item(), + observation["episode_return"].item(), + ) + env.close() + logging.info( + "Average returns over %i steps: %.1f", num_episodes, sum(returns) / len(returns) + ) + + +class RandomNet(nn.Module): + def __init__(self, observation_shape, num_actions, use_lstm): + super(RandomNet, self).__init__() + del observation_shape, use_lstm + self.num_actions = num_actions + self.theta = torch.nn.Parameter(torch.zeros(self.num_actions)) + + def forward(self, inputs, core_state): + # print(inputs) + T, B, *_ = inputs["observation"].shape + zeros = self.theta * 0 + # set logits to 0 + policy_logits = zeros[None, :].expand(T * B, -1) + # set baseline to 0 + baseline = policy_logits.sum(dim=1).view(-1, B) + + # sample random action + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1).view( + T, B + ) + policy_logits = policy_logits.view(T, B, self.num_actions) + return ( + dict(policy_logits=policy_logits, baseline=baseline, action=action), + core_state, + ) + + def initial_state(self, batch_size): + return () + + +def _step_to_range(delta, num_steps): + """ + Range of `num_steps` values separated by distance `delta` centered around zero. + Given an image with x,y in [-1, 1], return represents the crop range and sampling points on one axis. + """ + delta_range = delta * torch.arange(-(num_steps // 2), (num_steps + 1) // 2) + return delta_range + + +class Crop(nn.Module): + """Helper class for NetHackNet below.""" + + def __init__(self, height, width, height_target, width_target): + super(Crop, self).__init__() + self.width = width # 79 + self.height = height # 21 + self.width_target = width_target # 9 + self.height_target = height_target # 9 + + # Create row-wise sampling grid and repeat across height_target rows + width_grid = _step_to_range(2 / (self.width - 1), self.width_target)[None, :].expand(self.height_target, -1) + # Create column-wise sampling grid and repeat across width_target columns + height_grid = _step_to_range(2 / (self.height - 1), height_target)[:, None].expand(-1, self.width_target) + + # "clone" necessary, https://github.com/pytorch/pytorch/issues/34880 + self.register_buffer("width_grid", width_grid.clone()) + self.register_buffer("height_grid", height_grid.clone()) + + def forward(self, inputs, coordinates): + """Calculates centered crop around given x,y coordinates. + Args: + inputs [B x H x W] + coordinates [B x 2] x,y coordinates + Returns: + [B x H' x W'] inputs cropped and centered around x,y coordinates. + """ + + assert inputs.shape[1] == self.height + assert inputs.shape[2] == self.width + + # Add another axis at 1 + inputs = inputs[:, None, :, :].float() + + # Extract x,y coordinates for each sample + x = coordinates[:, 0] + y = coordinates[:, 1] + + # Recenter x and y values around 0 and normalize to range [-1, 1] + x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2) + y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2) + + # Shifts sampling grid to be centered around x, y + grid = torch.stack( + [ + self.width_grid[None, :, :] + x_shift[:, None, None], + self.height_grid[None, :, :] + y_shift[:, None, None], + ], + dim=3, + ) + + # TODO: only cast to int if original tensor was int + return ( + torch.round(F.grid_sample(inputs, grid, align_corners=True)) + .squeeze(1) + .long() + ) + + +class NetHackNet(nn.Module): + def __init__( + self, + observation_shape, + num_actions, + use_lstm, + embedding_dim=32, + crop_dim=9, + num_layers=5, + goal=False + ): + super(NetHackNet, self).__init__() + self.goal = goal + + self.glyph_shape = observation_shape["glyphs"].shape # (21, 79) + self.blstats_size = observation_shape["blstats"].shape[0] # 27 + if self.goal: + self.goal_size = observation_shape["goal"].shape[0] # 785? + + self.num_actions = num_actions # 23 + self.use_lstm = use_lstm + + self.H = self.glyph_shape[0] # 21 + self.W = self.glyph_shape[1] # 79 + + self.k_dim = embedding_dim # 32 + self.h_dim = 512 + + self.crop_dim = crop_dim # 9 + + self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim) + + self.embed = nn.Embedding(nethack.MAX_GLYPH, self.k_dim) # 5976, 32 + + K = embedding_dim # number of input filters + F = 3 # filter dimensions + S = 1 # stride + P = 1 # padding + M = 16 # number of intermediate filters + Y = 8 # number of output filters + L = num_layers # number of convnet layers # 5 + + in_channels = [K] + [M] * (L - 1) # [32, 16, 16, 16, 16] + out_channels = [M] * (L - 1) + [Y] # [16, 16, 16, 16, 8] + + def interleave(xs, ys): + return [val for pair in zip(xs, ys) for val in pair] + + conv_extract = [ + nn.Conv2d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=(F, F), + stride=S, + padding=P, + ) + for i in range(L) + ] + + # Create a sequential net of alternating Conv2d and ELU layers + self.extract_representation = nn.Sequential( + *interleave(conv_extract, [nn.ELU()] * len(conv_extract)) + ) + + # CNN crop model. + conv_extract_crop = [ + nn.Conv2d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=(F, F), + stride=S, + padding=P, + ) + for i in range(L) + ] + + # Create a sequential net of alternating Conv2d and ELU layers + self.extract_crop_representation = nn.Sequential( + *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract)) + ) + + # Blstats output dim + out_dim = self.k_dim # 32 + # Map glyphs output dim + out_dim += self.H * self.W * Y # 32 + 21 * 79 * 8 = 13304 + # Cropped map glyphs output dim + out_dim += self.crop_dim**2 * Y # 13304 + 9 * 9 * 8 = 13952 + if self.goal: + # Goal output dim + out_dim += int(self.goal_size / 2) # 13952 + 392 = 14344 + + # Blstats encoding 27 to 32 + self.embed_blstats = nn.Sequential( + nn.Linear(self.blstats_size, self.k_dim), + nn.ReLU(), + nn.Linear(self.k_dim, self.k_dim), + nn.ReLU(), + ) + + # Goal encoding 27 to 32 + if self.goal: + goal_dim = int(self.goal_size / 2) + self.embed_goals = nn.Sequential( + nn.Linear(self.goal_size, goal_dim), + nn.ReLU(), + nn.Linear(goal_dim, goal_dim), + nn.ReLU(), + ) + + # Fully connected from embedding to policy 13304 to 512 + self.fc = nn.Sequential( + nn.Linear(out_dim, self.h_dim), + nn.ReLU(), + nn.Linear(self.h_dim, self.h_dim), + nn.ReLU(), + ) + + if self.use_lstm: + # LSTM 512 to 512 + self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1) + + # Policy 512 to 23 + self.policy = nn.Linear(self.h_dim, self.num_actions) + # Baseline 512 to 1 + self.baseline = nn.Linear(self.h_dim, 1) + + def initial_state(self, batch_size=1): + if self.use_lstm: + return tuple( + torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) + for _ in range(2) + ) + return tuple() + + def _select(self, embed, x): + # Work around slow backward pass of nn.Embedding, see + # https://github.com/pytorch/pytorch/issues/24912 + out = embed.weight.index_select(0, x.reshape(-1)) + return out.reshape(x.shape + (-1,)) + + def forward(self, env_outputs, core_state): + # Set up glyphs + # -- [T x B x H x W] + glyphs = env_outputs["glyphs"] # [1, 1, 21, 79] + T, B, *_ = glyphs.shape + # -- [B' x H x W] + glyphs = torch.flatten(glyphs, 0, 1) # Merge time and batch. # [1, 21, 79] + # -- [B x H x W] + glyphs = glyphs.long() + + # Bottom line stats and extract coordinates + # -- [T x B x F] + blstats = env_outputs["blstats"] # [1, 1, 27] + # -- [B' x F] + blstats = blstats.view(T * B, -1).float() # [1, 27] + # -- [B x 2] x,y coordinates + coordinates = blstats[:, :2] # [1, 2] + # TODO ??? + # coordinates[:, 0].add_(-1) + # -- [B x F] + # TODO: Remove? + blstats = blstats.view(T * B, -1).float() # [1, 27] + # -- [B x K] + blstats_emb = self.embed_blstats(blstats) # [1, 32] + assert blstats_emb.shape[0] == T * B + reps = [blstats_emb] + + # Cropped map glyphs + # -- [B x H' x W'] + # Crop glyphs observation around player x,y coordinates + crop = self.crop(glyphs, coordinates) # [1, 9, 9] + # -- [B x H' x W' x K] + # Embed glyphs + crop_emb = self._select(self.embed, crop) # [1, 9, 9, 32] + # CNN crop model. + # -- [B x K x W' x H'] + crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? # [1, 32, 9, 9] + # -- [B x W' x H' x K] + # Convolutional pass to get representation of cropped region # [1, 8, 9, 9] + crop_rep = self.extract_crop_representation(crop_emb) + # -- [B x K'] + crop_rep = crop_rep.view(T * B, -1) # [1, 648] + assert crop_rep.shape[0] == T * B + reps.append(crop_rep) + + # -- [B x H x W x K] + # Full map glyphs + glyphs_emb = self._select(self.embed, glyphs) # [1, 21, 79, 32] + # -- [B x K x W x H] + glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? # [1, 32, 79, 21] + # -- [B x W x H x K] + glyphs_rep = self.extract_representation(glyphs_emb) # [1, 8, 79, 21] + # -- [B x K'] + glyphs_rep = glyphs_rep.view(T * B, -1) # [1, 13272] + assert glyphs_rep.shape[0] == T * B + # -- [B x K''] + reps.append(glyphs_rep) + + # TODO: Goals + if self.goal: + # -- [T x B x F] + goals = env_outputs["goal"] # [1, 1, 785] + # -- [B' x F] + goals = goals.view(T * B, -1).float() # [1, 785] + # -- [B x F] + # TODO: Remove? + goals = goals.view(T * B, -1).float() # [1, 785] + # -- [B x K] + goals_emb = self.embed_goals(goals) # [1, 32] + assert goals_emb.shape[0] == T * B + reps.append(goals_emb) + + # [32, 648, 13272] + st = torch.cat(reps, dim=1) + + # -- [B x K] + st = self.fc(st) # [1, 512] + + if self.use_lstm: + core_input = st.view(T, B, -1) + core_output_list = [] + notdone = (~env_outputs["done"]).float() + for input, nd in zip(core_input.unbind(), notdone.unbind()): + # Reset core state to zero whenever an episode ended. + # Make `done` broadcastable with (num_layers, B, hidden_size) + # states: + nd = nd.view(1, -1, 1) + core_state = tuple(nd * s for s in core_state) + output, core_state = self.core(input.unsqueeze(0), core_state) + core_output_list.append(output) + core_output = torch.flatten(torch.cat(core_output_list), 0, 1) + else: + core_output = st + + # -- [B x A] + policy_logits = self.policy(core_output) + # -- [B x A] + baseline = self.baseline(core_output) + + if self.training: + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) + else: + # Don't sample when testing. + action = torch.argmax(policy_logits, dim=1) + + policy_logits = policy_logits.view(T, B, self.num_actions) + baseline = baseline.view(T, B) + action = action.view(T, B) + + return ( + dict(policy_logits=policy_logits, baseline=baseline, action=action), + core_state, + ) + + +Net = NetHackNet + + +def main(flags, wandb=None): + if flags.mode == "train": + train(flags, wandb_run=wandb) + else: + test(flags) + + +if __name__ == "__main__": + flags = parse_args() + + if flags.profile: + import cProfile + from pstats import Stats + pr = cProfile.Profile() + pr.enable() + + wandb_run = None + if flags.exp_name: + wandb_run = wandb.init( + project="syllabus", + entity="", + config=flags, + save_code=True, + name=flags.exp_name, + resume="allow", + id=f"{flags.exp_name}-{flags.wandb_id}" + ) + main(flags, wandb=wandb_run) + + if flags.profile: + pr.disable() + stats = Stats(pr) + stats.sort_stats('cumtime').print_stats(200) diff --git a/examples/task_wrappers/__init__.py b/examples/task_wrappers/__init__.py new file mode 100644 index 00000000..7d04cff4 --- /dev/null +++ b/examples/task_wrappers/__init__.py @@ -0,0 +1,23 @@ +import warnings + +from .cartpole_task_wrapper import CartPoleTaskWrapper + +try: + from .minigrid_task_wrapper import MinigridTaskWrapper +except ImportError as e: + warnings.warn(f"Unable to import the following minigrid dependencies: {e.name}") + +try: + from .minihack_task_wrapper import MinihackTaskWrapper +except ImportError as e: + warnings.warn(f"Unable to import the following minihack dependencies: {e.name}") + +try: + from .nethack_wrappers import NethackTaskWrapper, RenderCharImagesWithNumpyWrapperV2 +except ImportError as e: + warnings.warn(f"Unable to import the following nle dependencies: {e.name}") + +try: + from .procgen_task_wrapper import ProcgenTaskWrapper +except ImportError as e: + warnings.warn(f"Unable to import the following procgen dependencies: {e.name}") diff --git a/examples/task_wrappers/cartpole_task_wrapper.py b/examples/task_wrappers/cartpole_task_wrapper.py new file mode 100644 index 00000000..4e590632 --- /dev/null +++ b/examples/task_wrappers/cartpole_task_wrapper.py @@ -0,0 +1,24 @@ +from gymnasium.spaces import Box + +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class CartPoleTaskWrapper(TaskWrapper): + def __init__(self, env): + super().__init__(env) + self.task_space = TaskSpace(Box(-0.3, 0.3, shape=(2,))) + self.task = (-0.02, 0.02) + self.total_reward = 0 + + def reset(self, *args, **kwargs): + self.total_reward = 0 + if "new_task" in kwargs: + new_task = kwargs.pop("new_task") + self.task = new_task + return self.env.reset(options={"low": self.task[0], "high": self.task[1]}) + + def _task_completion(self, obs, rew, term, trunc, info) -> float: + # Return percent of optimal reward + self.total_reward += rew + return self.total_reward / 500.0 diff --git a/examples/task_wrappers/minigrid_task_wrapper.py b/examples/task_wrappers/minigrid_task_wrapper.py new file mode 100644 index 00000000..de2e699b --- /dev/null +++ b/examples/task_wrappers/minigrid_task_wrapper.py @@ -0,0 +1,88 @@ +""" Task wrapper that can select a new MiniGrid task on reset. """ +import gymnasium as gym +import numpy as np +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class MinigridTaskWrapper(TaskWrapper): + """ + This wrapper allows you to change the task of an NLE environment. + """ + def __init__(self, env: gym.Env): + super().__init__(env) + try: + from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX + except ImportError: + warnings.warn("Unable to import gym_minigrid.") + + self.observation_space = gym.spaces.Box( + low=0, + high=255, + shape=(self.env.width, self.env.height, 3), # number of cells + dtype='uint8' + ) + m, n, c = self.observation_space.shape + self.observation_space = gym.spaces.Box( + self.observation_space.low[0, 0, 0], + self.observation_space.high[0, 0, 0], + [c, m, n], + dtype=self.observation_space.dtype) + + # Set up task space + self.task_space = TaskSpace(gym.spaces.Discrete(4000), list(np.arange(4000))) + self.task = None + + def reset(self, new_task=None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + self.done = False + self.episode_return = 0 + + return self.observation(self.env.reset(**kwargs)["image"]) + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + Ignores requests for unknown tasks or task changes outside of a reset. + """ + seed = int(new_task) + self.task = seed + self.env.seed(seed) + + def step(self, action): + """ + Step through environment and update task completion. + """ + # assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()" + obs, rew, term, trunc, info = self.env.step(action) + obs = self.observation(obs["image"]) + + self.episode_return += rew + self.done = term or trunc + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + + return obs, rew, term, trunc, info + + def observation(self, obs): + env = self.unwrapped + full_grid = env.grid.encode() + full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([ + OBJECT_TO_IDX['agent'], + COLOR_TO_IDX['red'], + env.agent_dir + ]) + + obs = full_grid + return obs.transpose(2, 0, 1) diff --git a/examples/task_wrappers/minihack_task_wrapper.py b/examples/task_wrappers/minihack_task_wrapper.py new file mode 100644 index 00000000..f7f61be2 --- /dev/null +++ b/examples/task_wrappers/minihack_task_wrapper.py @@ -0,0 +1,30 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import gymnasium as gym +from gymnasium import spaces +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class MinihackTaskWrapper(TaskWrapper): + """ + This wrapper simply changes the seed of a Minigrid environment. + """ + def __init__(self, env: gym.Env): + super().__init__(env) + self.env = env + + self.task: str = 1 + + # Task completion metrics + self.episode_return = 0 + self.task_space = TaskSpace(spaces.Discrete(1000), list(range(1000))) + + def reset(self, new_task: int = None, **kwargs): + # Change task if new one is provided + # if new_task is not None: + # self.change_task(new_task) + + self.episode_return = 0 + self.current_task = new_task + self.env.seed(new_task) + return self.observation(self.env.reset(**kwargs)) diff --git a/examples/task_wrappers/nethack_wrappers.py b/examples/task_wrappers/nethack_wrappers.py new file mode 100644 index 00000000..ba0c2cae --- /dev/null +++ b/examples/task_wrappers/nethack_wrappers.py @@ -0,0 +1,493 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import os +import time +from typing import Any, Dict, List, Tuple + +import cv2 +import gymnasium as gym +import numpy as np +# import render_utils +from gymnasium.utils.step_api_compatibility import step_api_compatibility +from nle import nethack +from nle.env import base +from nle.env.tasks import (NetHackChallenge, NetHackEat, NetHackGold, + NetHackOracle, NetHackScore, NetHackScout, + NetHackStaircase, NetHackStaircasePet) +from numba import njit +from PIL import Image, ImageDraw, ImageFont +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 + +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class NethackTaskWrapper(TaskWrapper): + """ + This wrapper allows you to change the task of an NLE environment. + + This wrapper was designed to meet two goals. + 1. Allow us to change the task of the NLE environment at the start of an episode + 2. Allow us to use the predefined NLE task definitions without copying/modifying their code. + This makes it easier to integrate with other work on nethack tasks or curricula. + + Each task is defined as a subclass of the NLE, so you need to cast and reinitialize the + environment to change its task. This wrapper manipulates the __class__ property to achieve this, + but does so in a safe way. Specifically, we ensure that the instance variables needed for each + task are available and reset at the start of the episode regardless of which task is active. + """ + def __init__( + self, + env: gym.Env, + additional_tasks: List[base.NLE] = None, + use_default_tasks: bool = True, + env_kwargs: Dict[str, Any] = {}, + wrappers: List[Tuple[gym.Wrapper, List[Any], Dict[str, Any]]] = None + ): + super().__init__(env) + self.env = env + self.task = NetHackScore + self._init_kwargs = env_kwargs + if self.env.__class__ == NetHackChallenge: + self._no_progress_timeout = self._init_kwargs.pop("no_progress_timeout", 150) + + # This is set to False during reset + self.done = True + + # Add nethack tasks provided by the base NLE + task_list: List[base.NLE] = [] + if use_default_tasks: + task_list = [ + NetHackScore, + NetHackStaircase, + NetHackStaircasePet, + NetHackOracle, + NetHackGold, + NetHackEat, + NetHackScout, + ] + + # Add in custom nethack tasks + if additional_tasks: + for task in additional_tasks: + assert isinstance(task, base.NLE), "Env must subclass the base NLE" + task_list.append(task) + + self.task_list = task_list + gym_space = gym.spaces.Discrete(len(self.task_list)) + self.task_space = TaskSpace(gym_space, task_list) + + # Add goal space to observation + # self.observation_space = copy.deepcopy(self.env.observation_space) + # self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list)) + + # Task completion metrics + self.episode_return = 0 + + # TODO: Deal with wrappers + self._nethack_env = self.env + while self._nethack_env.__class__ not in self.task_list and self._nethack_env.__class__ != NetHackChallenge: + if self._nethack_env.__class__ == GymV21CompatibilityV0: + self._nethack_env = self._nethack_env.gym_env + else: + self._nethack_env = self._nethack_env.env + + # Initialize missing instance variables + self._nethack_env.oracle_glyph = None + + def _task_name(self, task): + return task.__name__ + + def reset(self, new_task=None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + # Change task if new one is provided + new_task = np.random.choice(self.task_list) + if new_task is not None: + self.change_task(new_task) + + self.done = False + self.episode_return = 0 + + return self.observation(self.env.reset(**kwargs)) + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + Ignores requests for unknown tasks or task changes outside of a reset. + """ + # Ignore new task if mid episode + if self.task.__init__ != new_task.__init__ and not self.done: + print(f"Given task {self._task_name(new_task)} needs to be reinitialized.\ + Ignoring request to change task and keeping {self.task.__name__}") + return + + # Ignore if task is unknown + if new_task not in self.task_list: + print(f"Given task {new_task} not in task list.\ + Ignoring request to change task and keeping {self.env.__class__.__name__}") + return + + # Update current task + self.task = new_task + self._nethack_env.__class__ = new_task + + # If task requires reinitialization + # if type(self._nethack_env).__init__ != NetHackScore.__init__: + # self._nethack_env.__init__(actions=nethack.ACTIONS, **self._init_kwargs) + + def _encode_goal(self): + goal_encoding = np.zeros(len(self.task_list)) + index = self.task_list.index(self.task) + goal_encoding[index] = 1 + return goal_encoding + + def observation(self, observation): + """ + Parses current inventory and new items gained this timestep from the observation. + Returns a modified observation. + """ + # Add goal to observation + # observation['goal'] = self._encode_goal() + return observation + + def _task_completion(self, obs, rew, term, trunc, info): + # TODO: Add real task completion metrics + completion = 0.0 + if self.task == 0: + completion = self.episode_return / 1000 + elif self.task == 1: + completion = self.episode_return + elif self.task == 2: + completion = self.episode_return + elif self.task == 3: + completion = self.episode_return + elif self.task == 4: + completion = self.episode_return / 1000 + elif self.task == 5: + completion = self.episode_return / 10 + elif self.task == 6: + completion = self.episode_return / 100 + + return min(max(completion, 0.0), 1.0) + + def step(self, action): + """ + Step through environment and update task completion. + """ + obs, rew, term, trunc, info = step_api_compatibility(self.env.step(action), output_truncation_bool=True) + # self.episode_return += rew + self.done = term or trunc + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + return self.observation(obs), rew, term, trunc, info + + +SMALL_FONT_PATH = os.path.abspath("syllabus/examples/utils/Hack-Regular.ttf") + +# Mapping of 0-15 colors used. +# Taken from bottom image here. It seems about right +# https://i.stack.imgur.com/UQVe5.png +COLORS = [ + "#000000", + "#800000", + "#008000", + "#808000", + "#000080", + "#800080", + "#008080", + "#808080", # - flipped these ones around + "#C0C0C0", # | the gray-out dull stuff + "#FF0000", + "#00FF00", + "#FFFF00", + "#0000FF", + "#FF00FF", + "#00FFFF", + "#FFFFFF", +] + + +@njit +def _tile_characters_to_image( + out_image, + chars, + colors, + output_height_chars, + output_width_chars, + char_array, + offset_h, + offset_w, +): + """ + Build an image using cached images of characters in char_array to out_image + """ + char_height = char_array.shape[3] + char_width = char_array.shape[4] + for h in range(output_height_chars): + h_char = h + offset_h + # Stuff outside boundaries is not visible, so + # just leave it black + if h_char < 0 or h_char >= chars.shape[0]: + continue + for w in range(output_width_chars): + w_char = w + offset_w + if w_char < 0 or w_char >= chars.shape[1]: + continue + char = chars[h_char, w_char] + color = colors[h_char, w_char] + h_pixel = h * char_height + w_pixel = w * char_width + out_image[ + :, h_pixel : h_pixel + char_height, w_pixel : w_pixel + char_width + ] = char_array[char, color] + + +def _initialize_char_array(font_size, rescale_font_size): + """Draw all characters in PIL and cache them in numpy arrays + + if rescale_font_size is given, assume it is (width, height) + + Returns a np array of (num_chars, num_colors, char_height, char_width, 3) + """ + try: + font = ImageFont.truetype(SMALL_FONT_PATH, font_size) + except OSError as e: + raise ValueError("Change SMALL_FONT_PATH to point to syllabus/examples/utils/Hack-Regular.ttf") from e + + dummy_text = "".join( + [(chr(i) if chr(i).isprintable() else " ") for i in range(256)] + ) + _, _, image_width, image_height = font.getbbox(dummy_text) + # Above can not be trusted (or its siblings).... + image_width = int(np.ceil(image_width / 256) * 256) + + char_width = rescale_font_size[0] + char_height = rescale_font_size[1] + + char_array = np.zeros((256, 16, char_height, char_width, 3), dtype=np.uint8) + image = Image.new("RGB", (image_width, image_height)) + image_draw = ImageDraw.Draw(image) + for color_index in range(16): + image_draw.rectangle((0, 0, image_width, image_height), fill=(0, 0, 0)) + image_draw.text((0, 0), dummy_text, fill=COLORS[color_index], spacing=0) + + arr = np.array(image).copy() + arrs = np.array_split(arr, 256, axis=1) + for char_index in range(256): + char = arrs[char_index] + if rescale_font_size: + char = cv2.resize(char, rescale_font_size, interpolation=cv2.INTER_AREA) + char_array[char_index, color_index] = char + return char_array + + +class RenderCharImagesWithNumpyWrapper(gym.Wrapper): + """ + Render characters as images, using PIL to render characters like we humans see on screen + but then some caching and numpy stuff to speed up things. + + To speed things up, crop image around the player. + """ + + def __init__( + self, + env, + font_size=9, + crop_size=12, + rescale_font_size=(6, 6), + blstats_cursor=False, + ): + super().__init__(env) + self.char_array = _initialize_char_array(font_size, rescale_font_size) + self.char_height = self.char_array.shape[2] + self.char_width = self.char_array.shape[3] + # Transpose for CHW + self.char_array = self.char_array.transpose(0, 1, 4, 2, 3) + + self.crop_size = crop_size + self.blstats_cursor = blstats_cursor + + self.half_crop_size = crop_size // 2 + self.output_height_chars = crop_size + self.output_width_chars = crop_size + self.chw_image_shape = ( + 3, + self.output_height_chars * self.char_height, + self.output_width_chars * self.char_width, + ) + + obs_spaces = { + "screen_image": gym.spaces.Box( + low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8 + ) + } + obs_spaces.update( + [ + (k, self.env.observation_space[k]) + for k in self.env.observation_space + if k not in ["tty_chars", "tty_colors"] + ] + ) + self.observation_space = gym.spaces.Dict(obs_spaces) + + def _render_text_to_image(self, obs): + chars = obs["tty_chars"] + colors = obs["tty_colors"] + offset_w = 0 + offset_h = 0 + if self.crop_size: + # Center around player + if self.blstats_cursor: + center_x, center_y = obs["blstats"][:2] + else: + center_y, center_x = obs["tty_cursor"] + offset_h = center_y - self.half_crop_size + offset_w = center_x - self.half_crop_size + + out_image = np.zeros(self.chw_image_shape, dtype=np.uint8) + + _tile_characters_to_image( + out_image=out_image, + chars=chars, + colors=colors, + output_height_chars=self.output_height_chars, + output_width_chars=self.output_width_chars, + char_array=self.char_array, + offset_h=offset_h, + offset_w=offset_w, + ) + + obs["screen_image"] = out_image + return obs + + def step(self, action): + obs, reward, done, info = self.env.step(action) + obs = self._render_text_to_image(obs) + return obs, reward, done, info + + def reset(self): + obs = self.env.reset() + obs = self._render_text_to_image(obs) + return obs + + +class RenderCharImagesWithNumpyWrapperV2(gym.Wrapper): + """ + Same as V1, but simpler and faster. + """ + + def __init__( + self, + env, + font_size=9, + crop_size=12, + rescale_font_size=(6, 6), + ): + super().__init__(env) + self.char_array = _initialize_char_array(font_size, rescale_font_size) + self.char_height = self.char_array.shape[2] + self.char_width = self.char_array.shape[3] + # Transpose for CHW + self.char_array = self.char_array.transpose(0, 1, 4, 2, 3) + self.char_array = np.ascontiguousarray(self.char_array) + self.crop_size = crop_size + + crop_rows = crop_size or nethack.nethack.TERMINAL_SHAPE[0] + crop_cols = crop_size or nethack.nethack.TERMINAL_SHAPE[1] + + self.chw_image_shape = ( + 3, + crop_rows * self.char_height, + crop_cols * self.char_width, + ) + + obs_spaces = { + "screen_image": gym.spaces.Box( + low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8 + ) + } + obs_spaces.update( + [ + (k, self.env.observation_space[k]) + for k in self.env.observation_space + # if k not in ["tty_chars", "tty_colors"] + ] + ) + self.observation_space = gym.spaces.Dict(obs_spaces) + + def _populate_obs(self, obs): + screen = np.zeros(self.chw_image_shape, order="C", dtype=np.uint8) + render_utils.render_crop( + obs["tty_chars"], + obs["tty_colors"], + obs["tty_cursor"], + self.char_array, + screen, + crop_size=self.crop_size, + ) + obs["screen_image"] = screen + + def step(self, action): + obs, reward, term, trunc, info = self.env.step(action) + self._populate_obs(obs) + return obs, reward, term, trunc, info + + def reset(self): + obs, info = self.env.reset() + self._populate_obs(obs) + return obs, info + + +if __name__ == "__main__": + def run_episode(env, task: str = None, verbose=1): + env.reset(new_task=task) + task_name = type(env.unwrapped).__name__ + term = trunc = False + ep_rew = 0 + while not (term or trunc): + action = env.action_space.sample() + _, rew, term, trunc, _ = env.step(action) + ep_rew += rew + if verbose: + print(f"Episodic reward for {task_name}: {ep_rew}") + + print("Testing NethackTaskWrapper") + N_EPISODES = 100 + + # Initialize NLE + nethack_env = NetHackScore() + nethack_env = GymV21CompatibilityV0(env=nethack_env) + + nethack_task_env = NethackTaskWrapper(nethack_env) + + task_list = [ + NetHackScore, + NetHackStaircase, + NetHackStaircasePet, + NetHackOracle, + NetHackGold, + NetHackEat, + NetHackScout, + ] + + start_time = time.time() + + for _ in range(N_EPISODES): + run_episode(nethack_task_env, verbose=0) + + end_time = time.time() + print(f"Run time same task: {end_time - start_time}") + start_time = time.time() + + for i in range(N_EPISODES): + nethack_task = task_list[i % 7] + run_episode(nethack_task_env, task=nethack_task, verbose=0) + + end_time = time.time() + print(f"Run time swapping tasks: {end_time - start_time}") diff --git a/examples/task_wrappers/pistonball_task_wrapper.py b/examples/task_wrappers/pistonball_task_wrapper.py new file mode 100644 index 00000000..d60abb7a --- /dev/null +++ b/examples/task_wrappers/pistonball_task_wrapper.py @@ -0,0 +1,35 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import gymnasium as gym +from gymnasium import spaces +from pettingzoo.butterfly import pistonball_v6 +from syllabus.core import PettingZooTaskWrapper +from syllabus.task_space import TaskSpace + + +class PistonballTaskWrapper(PettingZooTaskWrapper): + """ + This wrapper simply changes the seed of a Minigrid environment. + """ + def __init__(self, env: gym.Env): + super().__init__(env) + self.env = env + self.env.unwrapped.task: str = 1 + + # Task completion metrics + self.episode_return = 0 + self.task_space = TaskSpace(spaces.Discrete(11), list(range(11))) # 0.1 - 1.0 friction + + def reset(self, new_task: int = None, **kwargs): + # Change task if new one is provided + # if new_task is not None: + # self.change_task(new_task) + + self.episode_return = 0 + if new_task is not None: + task = new_task / 10 + # Inject current_task into the environment + self.env = pistonball_v6.parallel_env( + ball_friction=task, continuous=False, max_cycles=125 + ) + self.env.unwrapped.task = new_task + return self.observation(self.env.reset(**kwargs)) diff --git a/examples/task_wrappers/procgen_task_wrapper.py b/examples/task_wrappers/procgen_task_wrapper.py new file mode 100644 index 00000000..2296fd58 --- /dev/null +++ b/examples/task_wrappers/procgen_task_wrapper.py @@ -0,0 +1,86 @@ +import gymnasium as gym +import numpy as np +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +PROCGEN_RETURN_BOUNDS = { + "coinrun": (5, 10), + "starpilot": (2.5, 64), + "caveflyer": (3.5, 12), + "dodgeball": (1.5, 19), + "fruitbot": (-1.5, 32.4), + "chaser": (0.5, 13), + "miner": (1.5, 13), + "jumper": (3, 10), + "leaper": (3, 10), + "maze": (5, 10), + "bigfish": (1, 40), + "heist": (3.5, 10), + "climber": (2, 12.6), + "plunder": (4.5, 30), + "ninja": (3.5, 10), + "bossfight": (0.5, 13), +} + + +class ProcgenTaskWrapper(TaskWrapper): + """ + This wrapper allows you to change the task of an NLE environment. + """ + def __init__(self, env: gym.Env, env_id, seed=0): + super().__init__(env) + self.task_space = TaskSpace(gym.spaces.Discrete(200), list(np.arange(0, 200))) + self.env_id = env_id + self.task = seed + self.seed(seed) + self.episode_return = 0 + + self.observation_space = self.env.observation_space + + def seed(self, seed): + self.env.gym_env.unwrapped._venv.seed(int(seed), 0) + + def reset(self, new_task=None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + self.episode_return = 0.0 + + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + Ignores requests for unknown tasks or task changes outside of a reset. + """ + seed = int(new_task) + self.task = seed + self.seed(seed) + + def step(self, action): + """ + Step through environment and update task completion. + """ + obs, rew, term, trunc, info = self.env.step(action) + self.episode_return += rew + + env_min, env_max = PROCGEN_RETURN_BOUNDS[self.env_id] + normalized_return = (self.episode_return - env_min) / (env_max - env_min) + info["task_completion"] = normalized_return + + return self.observation(obs), rew, term, trunc, info + + def observation(self, obs): + return obs diff --git a/examples/training_scripts/cleanrl_procgen_centralplr.py b/examples/training_scripts/cleanrl_procgen_centralplr.py new file mode 100644 index 00000000..b848d693 --- /dev/null +++ b/examples/training_scripts/cleanrl_procgen_centralplr.py @@ -0,0 +1,536 @@ +""" An example applying Syllabus Prioritized Level Replay to Procgen. This code is based on https://github.com/facebookresearch/level-replay/blob/main/train.py + +NOTE: In order to efficiently change the seed of a procgen environment directly without reinitializing it, +we rely on Minqi Jiang's custom branch of procgen found here: https://github.com/minqi/procgen +""" +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool + +import gym as openai_gym +import gymnasium as gym +import numpy as np +import procgen # noqa: F401 +from procgen import ProcgenEnv +import torch +import torch.nn as nn +import torch.optim as optim +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 +from torch.utils.tensorboard import SummaryWriter + +from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum +from syllabus.curricula import CentralizedPrioritizedLevelReplay, DomainRandomization, LearningProgressCurriculum, SequentialCurriculum +from syllabus.examples.models import ProcgenAgent +from syllabus.examples.task_wrappers import ProcgenTaskWrapper +from syllabus.examples.utils.vecenv import VecMonitor, VecNormalize, VecExtractDictObs + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--logging-dir", type=str, default=".", + help="the base directory for logging and wandb storage.") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="starpilot", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=int(25e6), + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=256, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.999, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=8, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=3, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + + # Procgen arguments + parser.add_argument("--full-dist", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Train on full distribution of levels.") + + # Curriculum arguments + parser.add_argument("--curriculum", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will use curriculum learning") + parser.add_argument("--curriculum-method", type=str, default="plr", + help="curriculum method to use") + parser.add_argument("--num-eval-episodes", type=int, default=10, + help="the number of episodes to evaluate the agent on after each policy update.") + + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +PROCGEN_RETURN_BOUNDS = { + "coinrun": (5, 10), + "starpilot": (2.5, 64), + "caveflyer": (3.5, 12), + "dodgeball": (1.5, 19), + "fruitbot": (-1.5, 32.4), + "chaser": (0.5, 13), + "miner": (1.5, 13), + "jumper": (3, 10), + "leaper": (3, 10), + "maze": (5, 10), + "bigfish": (1, 40), + "heist": (3.5, 10), + "climber": (2, 12.6), + "plunder": (4.5, 30), + "ninja": (3.5, 10), + "bossfight": (0.5, 13), +} + + +def make_env(env_id, seed, curriculum=None, start_level=0, num_levels=1): + def thunk(): + env = openai_gym.make(f"procgen-{env_id}-v0", distribution_mode="easy", start_level=start_level, num_levels=num_levels) + env = GymV21CompatibilityV0(env=env) + if curriculum is not None: + env = ProcgenTaskWrapper(env, env_id, seed=seed) + env = MultiProcessingSyncWrapper( + env, + curriculum.get_components(), + update_on_step=False, + task_space=env.task_space, + ) + return env + return thunk + + +def wrap_vecenv(vecenv): + vecenv.is_vector_env = True + vecenv = VecMonitor(venv=vecenv, filename=None, keep_buf=100) + vecenv = VecNormalize(venv=vecenv, ob=False, ret=True) + return vecenv + + +def full_level_replay_evaluate( + env_name, + policy, + num_episodes, + device, + num_levels=1 # Not used +): + policy.eval() + + eval_envs = ProcgenEnv( + num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False + ) + eval_envs = VecExtractDictObs(eval_envs, "rgb") + eval_envs = wrap_vecenv(eval_envs) + + # Seed environments + seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] + for i, seed in enumerate(seeds): + eval_envs.seed(seed, i) + + eval_obs, _ = eval_envs.reset() + eval_episode_rewards = [-1] * num_episodes + + while -1 in eval_episode_rewards: + with torch.no_grad(): + eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) + + eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) + for i, info in enumerate(infos): + if 'episode' in info.keys() and eval_episode_rewards[i] == -1: + eval_episode_rewards[i] = info['episode']['r'] + + mean_returns = np.mean(eval_episode_rewards) + stddev_returns = np.std(eval_episode_rewards) + env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] + normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) + policy.train() + return mean_returns, stddev_returns, normalized_mean_returns + + +def level_replay_evaluate( + env_name, + policy, + num_episodes, + device, + num_levels=0 +): + policy.eval() + + eval_envs = ProcgenEnv( + num_envs=args.num_eval_episodes, env_name=env_name, num_levels=num_levels, start_level=0, distribution_mode="easy", paint_vel_info=False + ) + eval_envs = VecExtractDictObs(eval_envs, "rgb") + eval_envs = wrap_vecenv(eval_envs) + eval_obs, _ = eval_envs.reset() + eval_episode_rewards = [-1] * num_episodes + + while -1 in eval_episode_rewards: + with torch.no_grad(): + eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) + + eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) + for i, info in enumerate(infos): + if 'episode' in info.keys() and eval_episode_rewards[i] == -1: + eval_episode_rewards[i] = info['episode']['r'] + + # print(eval_episode_rewards) + mean_returns = np.mean(eval_episode_rewards) + stddev_returns = np.std(eval_episode_rewards) + env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] + normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) + policy.train() + return mean_returns, stddev_returns, normalized_mean_returns + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + dir=args.logging_dir + ) + # wandb.run.log_code("./syllabus/examples") + + writer = SummaryWriter(os.path.join(args.logging_dir, "./runs/{run_name}")) + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + print("Device:", device) + + # Curriculum setup + curriculum = None + if args.curriculum: + sample_env = openai_gym.make(f"procgen-{args.env_id}-v0") + sample_env = GymV21CompatibilityV0(env=sample_env) + sample_env = ProcgenTaskWrapper(sample_env, args.env_id, seed=args.seed) + + # Intialize Curriculum Method + if args.curriculum_method == "plr": + print("Using prioritized level replay.") + curriculum = CentralizedPrioritizedLevelReplay( + sample_env.task_space, + num_steps=args.num_steps, + num_processes=args.num_envs, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + task_sampler_kwargs_dict={"strategy": "value_l1"} + ) + elif args.curriculum_method == "dr": + print("Using domain randomization.") + curriculum = DomainRandomization(sample_env.task_space) + elif args.curriculum_method == "lp": + print("Using learning progress.") + curriculum = LearningProgressCurriculum(sample_env.task_space) + elif args.curriculum_method == "sq": + print("Using sequential curriculum.") + curricula = [] + stopping = [] + for i in range(199): + curricula.append(i + 1) + stopping.append("steps>=50000") + curricula.append(list(range(i + 1))) + stopping.append("steps>=50000") + curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space) + else: + raise ValueError(f"Unknown curriculum method {args.curriculum_method}") + curriculum = make_multiprocessing_curriculum(curriculum) + del sample_env + + # env setup + print("Creating env") + envs = gym.vector.AsyncVectorEnv( + [ + make_env( + args.env_id, + args.seed + i, + curriculum=curriculum if args.curriculum else None, + num_levels=1 if args.curriculum else 0 + ) + for i in range(args.num_envs) + ] + ) + envs = wrap_vecenv(envs) + + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + print("Creating agent") + agent = ProcgenAgent( + envs.single_observation_space.shape, + envs.single_action_space.n, + arch="large", + base_kwargs={'recurrent': False, 'hidden_size': 256} + ).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset() + next_obs = torch.Tensor(next_obs).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + episode_rewards = deque(maxlen=10) + completed_episodes = 0 + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, term, trunc, info = envs.step(action.cpu().numpy()) + done = np.logical_or(term, trunc) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + completed_episodes += sum(done) + + for item in info: + if "episode" in item.keys(): + episode_rewards.append(item['episode']['r']) + print(f"global_step={global_step}, episodic_return={item['episode']['r']}") + writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) + if curriculum is not None: + curriculum.log_metrics(writer, global_step) + break + + # Syllabus curriculum update + if args.curriculum and args.curriculum_method == "plr": + with torch.no_grad(): + next_value = agent.get_value(next_obs) + tasks = envs.get_attr("task") + + update = { + "update_type": "on_demand", + "metrics": { + "value": value, + "next_value": next_value, + "rew": reward, + "dones": done, + "tasks": tasks, + }, + } + curriculum.update(update) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + if args.gae: + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + else: + returns = torch.zeros_like(rewards).to(device) + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + next_return = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + next_return = returns[t + 1] + returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return + advantages = returns - values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # Evaluate agent + mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=0 + ) + full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=0 + ) + mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=200 + ) + full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=200 + ) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("charts/episode_returns", np.mean(episode_rewards), global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step) + writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step) + writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step) + writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step) + writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step) + writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step) + + writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step) + writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step) + writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step) + writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step) + writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step) + writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step) + + writer.add_scalar("curriculum/completed_episodes", completed_episodes, step) + + envs.close() + writer.close() diff --git a/examples/training_scripts/cleanrl_procgen_plr.py b/examples/training_scripts/cleanrl_procgen_plr.py new file mode 100644 index 00000000..dabcd500 --- /dev/null +++ b/examples/training_scripts/cleanrl_procgen_plr.py @@ -0,0 +1,528 @@ +""" An example applying Syllabus Prioritized Level Replay to Procgen. This code is based on https://github.com/facebookresearch/level-replay/blob/main/train.py + +NOTE: In order to efficiently change the seed of a procgen environment directly without reinitializing it, +we rely on Minqi Jiang's custom branch of procgen found here: https://github.com/minqi/procgen +""" +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool + +import gym as openai_gym +import gymnasium as gym +import numpy as np +import procgen # noqa: F401 +from procgen import ProcgenEnv +import torch +import torch.nn as nn +import torch.optim as optim +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 +from torch.utils.tensorboard import SummaryWriter + +from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum +from syllabus.curricula import PrioritizedLevelReplay, DomainRandomization, LearningProgressCurriculum, SequentialCurriculum +from syllabus.examples.models import ProcgenAgent +from syllabus.examples.task_wrappers import ProcgenTaskWrapper +from syllabus.examples.utils.vecenv import VecMonitor, VecNormalize, VecExtractDictObs + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--logging-dir", type=str, default=".", + help="the base directory for logging and wandb storage.") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="starpilot", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=int(25e6), + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=256, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.999, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=8, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=3, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + + # Procgen arguments + parser.add_argument("--full-dist", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Train on full distribution of levels.") + + # Curriculum arguments + parser.add_argument("--curriculum", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will use curriculum learning") + parser.add_argument("--curriculum-method", type=str, default="plr", + help="curriculum method to use") + parser.add_argument("--num-eval-episodes", type=int, default=10, + help="the number of episodes to evaluate the agent on after each policy update.") + + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +PROCGEN_RETURN_BOUNDS = { + "coinrun": (5, 10), + "starpilot": (2.5, 64), + "caveflyer": (3.5, 12), + "dodgeball": (1.5, 19), + "fruitbot": (-1.5, 32.4), + "chaser": (0.5, 13), + "miner": (1.5, 13), + "jumper": (3, 10), + "leaper": (3, 10), + "maze": (5, 10), + "bigfish": (1, 40), + "heist": (3.5, 10), + "climber": (2, 12.6), + "plunder": (4.5, 30), + "ninja": (3.5, 10), + "bossfight": (0.5, 13), +} + + +def make_env(env_id, seed, curriculum=None, start_level=0, num_levels=1): + def thunk(): + env = openai_gym.make(f"procgen-{env_id}-v0", distribution_mode="easy", start_level=start_level, num_levels=num_levels) + env = GymV21CompatibilityV0(env=env) + if curriculum is not None: + env = ProcgenTaskWrapper(env, env_id, seed=seed) + env = MultiProcessingSyncWrapper( + env, + curriculum.get_components(), + update_on_step=curriculum.requires_step_updates, + task_space=env.task_space, + ) + return env + return thunk + + +def wrap_vecenv(vecenv): + vecenv.is_vector_env = True + vecenv = VecMonitor(venv=vecenv, filename=None, keep_buf=100) + vecenv = VecNormalize(venv=vecenv, ob=False, ret=True) + return vecenv + + +def full_level_replay_evaluate( + env_name, + policy, + num_episodes, + device, + num_levels=1 # Not used +): + policy.eval() + + eval_envs = ProcgenEnv( + num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False + ) + eval_envs = VecExtractDictObs(eval_envs, "rgb") + eval_envs = wrap_vecenv(eval_envs) + + # Seed environments + seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] + for i, seed in enumerate(seeds): + eval_envs.seed(seed, i) + + eval_obs, _ = eval_envs.reset() + eval_episode_rewards = [-1] * num_episodes + + while -1 in eval_episode_rewards: + with torch.no_grad(): + eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) + + eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) + for i, info in enumerate(infos): + if 'episode' in info.keys() and eval_episode_rewards[i] == -1: + eval_episode_rewards[i] = info['episode']['r'] + + mean_returns = np.mean(eval_episode_rewards) + stddev_returns = np.std(eval_episode_rewards) + env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] + normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) + policy.train() + return mean_returns, stddev_returns, normalized_mean_returns + + +def level_replay_evaluate( + env_name, + policy, + num_episodes, + device, + num_levels=0 +): + policy.eval() + + eval_envs = ProcgenEnv( + num_envs=args.num_eval_episodes, env_name=env_name, num_levels=num_levels, start_level=0, distribution_mode="easy", paint_vel_info=False + ) + eval_envs = VecExtractDictObs(eval_envs, "rgb") + eval_envs = wrap_vecenv(eval_envs) + eval_obs, _ = eval_envs.reset() + eval_episode_rewards = [-1] * num_episodes + + while -1 in eval_episode_rewards: + with torch.no_grad(): + eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) + + eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) + for i, info in enumerate(infos): + if 'episode' in info.keys() and eval_episode_rewards[i] == -1: + eval_episode_rewards[i] = info['episode']['r'] + + # print(eval_episode_rewards) + mean_returns = np.mean(eval_episode_rewards) + stddev_returns = np.std(eval_episode_rewards) + env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] + normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) + policy.train() + return mean_returns, stddev_returns, normalized_mean_returns + + +def make_value_fn(): + def get_value(obs): + obs = np.array(obs) + with torch.no_grad(): + return agent.get_value(torch.Tensor(obs).to(device)) + return get_value + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + dir=args.logging_dir + ) + # wandb.run.log_code("./syllabus/examples") + + writer = SummaryWriter(os.path.join(args.logging_dir, "./runs/{run_name}")) + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + print("Device:", device) + + # Curriculum setup + curriculum = None + if args.curriculum: + sample_env = openai_gym.make(f"procgen-{args.env_id}-v0") + sample_env = GymV21CompatibilityV0(env=sample_env) + sample_env = ProcgenTaskWrapper(sample_env, args.env_id, seed=args.seed) + + # Intialize Curriculum Method + if args.curriculum_method == "plr": + print("Using prioritized level replay.") + curriculum = PrioritizedLevelReplay( + sample_env.task_space, + sample_env.observation_space, + num_steps=args.num_steps, + num_processes=args.num_envs, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + task_sampler_kwargs_dict={"strategy": "value_l1"}, + get_value=make_value_fn(), + ) + elif args.curriculum_method == "dr": + print("Using domain randomization.") + curriculum = DomainRandomization(sample_env.task_space) + elif args.curriculum_method == "lp": + print("Using learning progress.") + curriculum = LearningProgressCurriculum(sample_env.task_space) + elif args.curriculum_method == "sq": + print("Using sequential curriculum.") + curricula = [] + stopping = [] + for i in range(199): + curricula.append(i + 1) + stopping.append("steps>=50000") + curricula.append(list(range(i + 1))) + stopping.append("steps>=50000") + curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space) + else: + raise ValueError(f"Unknown curriculum method {args.curriculum_method}") + curriculum = make_multiprocessing_curriculum(curriculum) + del sample_env + + # env setup + print("Creating env") + envs = gym.vector.AsyncVectorEnv( + [ + make_env( + args.env_id, + args.seed + i, + curriculum=curriculum if args.curriculum else None, + num_levels=1 if args.curriculum else 0 + ) + for i in range(args.num_envs) + ] + ) + envs = wrap_vecenv(envs) + + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + print("Creating agent") + agent = ProcgenAgent( + envs.single_observation_space.shape, + envs.single_action_space.n, + arch="large", + base_kwargs={'recurrent': False, 'hidden_size': 256} + ).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset() + next_obs = torch.Tensor(next_obs).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + episode_rewards = deque(maxlen=10) + completed_episodes = 0 + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, term, trunc, info = envs.step(action.cpu().numpy()) + done = np.logical_or(term, trunc) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + completed_episodes += sum(done) + + for item in info: + if "episode" in item.keys(): + episode_rewards.append(item['episode']['r']) + print(f"global_step={global_step}, episodic_return={item['episode']['r']}") + writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) + if curriculum is not None: + curriculum.log_metrics(writer, global_step) + break + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + if args.gae: + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + else: + returns = torch.zeros_like(rewards).to(device) + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + next_return = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + next_return = returns[t + 1] + returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return + advantages = returns - values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # Evaluate agent + mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=0 + ) + full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=0 + ) + mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=200 + ) + full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate( + args.env_id, agent, args.num_eval_episodes, device, num_levels=200 + ) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("charts/episode_returns", np.mean(episode_rewards), global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step) + writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step) + writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step) + writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step) + writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step) + writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step) + + writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step) + writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step) + writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step) + writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step) + writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step) + writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step) + + writer.add_scalar("curriculum/completed_episodes", completed_episodes, step) + + envs.close() + writer.close() diff --git a/examples/utils/__init__.py b/examples/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/utils/vecenv.py b/examples/utils/vecenv.py new file mode 100644 index 00000000..af3b187a --- /dev/null +++ b/examples/utils/vecenv.py @@ -0,0 +1,308 @@ +import time +from collections import deque + +import numpy as np + + +class VecEnv: + """ + An abstract asynchronous, vectorized environment. + Used to batch data from multiple copies of an environment, so that + each observation becomes an batch of observations, and expected action is a batch of actions to + be applied per-environment. + """ + closed = False + viewer = None + + metadata = { + 'render.modes': ['human', 'rgb_array'] + } + + def __init__(self, num_envs, observation_space, action_space): + self.num_envs = num_envs + self.observation_space = observation_space + self.action_space = action_space + + def reset(self): + """ + Reset all the environments and return an array of + observations, or a dict of observation arrays. + + If step_async is still doing work, that work will + be cancelled and step_wait() should not be called + until step_async() is invoked again. + """ + pass + + def step_async(self, actions): + """ + Tell all the environments to start taking a step + with the given actions. + Call step_wait() to get the results of the step. + + You should not call this if a step_async run is + already pending. + """ + pass + + def step_wait(self): + """ + Wait for the step taken with step_async(). + + Returns (obs, rews, dones, infos): + - obs: an array of observations, or a dict of + arrays of observations. + - rews: an array of rewards + - dones: an array of "episode done" booleans + - infos: a sequence of info objects + """ + pass + + def close_extras(self): + """ + Clean up the extra resources, beyond what's in this base class. + Only runs when not self.closed. + """ + pass + + def close(self): + if self.closed: + return + if self.viewer is not None: + self.viewer.close() + self.close_extras() + self.closed = True + + def step(self, actions): + """ + Step the environments synchronously. + + This is available for backwards compatibility. + """ + self.step_async(actions) + return self.step_wait() + + def step_env(self, actions, reset_random=False): + if reset_random: + self.step_env_reset_random_async(actions) + else: + self.step_env_async(actions) + return self.step_wait() + + def render(self, mode='human'): + raise NotImplementedError + + def get_images(self): + """ + Return RGB images from each environment + """ + raise NotImplementedError + + @property + def unwrapped(self): + if isinstance(self, VecEnvWrapper): + return self.venv.unwrapped + else: + return self + + def get_viewer(self): + if self.viewer is None: + from gym.envs.classic_control import rendering + self.viewer = rendering.SimpleImageViewer() + return self.viewer + + +class VecEnvWrapper(VecEnv): + """ + An environment wrapper that applies to an entire batch + of environments at once. + """ + + def __init__(self, venv, observation_space=None, action_space=None): + self.venv = venv + VecEnv.__init__(self, num_envs=venv.num_envs, + observation_space=observation_space or venv.observation_space, + action_space=action_space or venv.action_space) + + def step_async(self, actions): + self.venv.step_async(actions) + + def reset(self): + pass + + def step_wait(self): + pass + + def close(self): + return self.venv.close() + + def render(self, mode='human'): + return self.venv.render(mode=mode) + + def get_images(self): + return self.venv.get_images() + + def __getattr__(self, name): + if name.startswith('_'): + raise AttributeError("attempted to get missing private attribute '{}'".format(name)) + return getattr(self.venv, name) + + +class VecEnvObservationWrapper(VecEnvWrapper): + def process(self, obs): + pass + + def reset(self): + outputs = self.venv.reset() + if len(outputs) == 2: + obs, infos = outputs + else: + obs, infos = outputs, {} + return self.process(obs), infos + + def step_wait(self): + env_outputs = self.venv.step_wait() + if len(env_outputs) == 4: + obs, rews, terms, infos = env_outputs + truncs = np.zeros_like(terms) + else: + obs, rews, terms, truncs, infos = env_outputs + return self.process(obs), rews, terms, truncs, infos + + +class VecExtractDictObs(VecEnvObservationWrapper): + def __init__(self, venv, key): + self.key = key + super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) + + def process(self, obs): + return obs[self.key] + + +class VecNormalize(VecEnvWrapper): + """ + A vectorized wrapper that normalizes the observations + and returns from an environment. + """ + + def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, use_tf=False): + VecEnvWrapper.__init__(self, venv) + self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None + self.ret_rms = RunningMeanStd(shape=()) if ret else None + self.clipob = clipob + self.cliprew = cliprew + self.ret = np.zeros(self.num_envs) + self.gamma = gamma + self.epsilon = epsilon + + def step_wait(self): + obs, rews, terms, truncs, infos = self.venv.step_wait() + news = np.logical_or(terms, truncs) + self.ret = self.ret * self.gamma + rews + obs = self._obfilt(obs) + if self.ret_rms: + self.ret_rms.update(self.ret) + rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) + self.ret[news] = 0. + return obs, rews, terms, truncs, infos + + def _obfilt(self, obs): + if self.ob_rms: + self.ob_rms.update(obs) + obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) + return obs + else: + return obs + + def reset(self, seed=None): + self.ret = np.zeros(self.num_envs) + if seed is not None: + obs, infos = self.venv.reset(seed=seed) + else: + obs, infos = self.venv.reset() + return self._obfilt(obs), infos + + +class VecMonitor(VecEnvWrapper): + def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()): + VecEnvWrapper.__init__(self, venv) + self.eprets = None + self.eplens = None + self.epcount = 0 + self.tstart = time.time() + self.results_writer = None + self.info_keywords = info_keywords + self.keep_buf = keep_buf + if self.keep_buf: + self.epret_buf = deque([], maxlen=keep_buf) + self.eplen_buf = deque([], maxlen=keep_buf) + + def reset(self, seed=None): + if seed is not None: + obs, infos = self.venv.reset(seed=seed) + else: + obs, infos = self.venv.reset() + self.eprets = np.zeros(self.num_envs, 'f') + self.eplens = np.zeros(self.num_envs, 'i') + return obs, infos + + def step_wait(self): + obs, rews, terms, truncs, infos = self.venv.step_wait() + dones = np.logical_or(terms, truncs) + self.eprets += rews + self.eplens += 1 + # Convert dict of lists to list of dicts + if isinstance(infos, dict): + infos = [dict(zip(infos, t)) for t in zip(*infos.values())] + newinfos = list(infos[:]) + for i in range(len(dones)): + if dones[i]: + info = infos[i].copy() + ret = self.eprets[i] + eplen = self.eplens[i] + epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} + for k in self.info_keywords: + epinfo[k] = info[k] + info['episode'] = epinfo + if self.keep_buf: + self.epret_buf.append(ret) + self.eplen_buf.append(eplen) + self.epcount += 1 + self.eprets[i] = 0 + self.eplens[i] = 0 + if self.results_writer: + self.results_writer.write_row(epinfo) + newinfos[i] = info + return obs, rews, terms, truncs, newinfos + + +class RunningMeanStd(): + # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + def __init__(self, epsilon=1e-4, shape=()): + self.mean = np.zeros(shape, 'float64') + self.var = np.ones(shape, 'float64') + self.count = epsilon + + def update(self, x): + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + batch_count = x.shape[0] + self.update_from_moments(batch_mean, batch_var, batch_count) + + def update_from_moments(self, batch_mean, batch_var, batch_count): + self.mean, self.var, self.count = update_mean_var_count_from_moments( + self.mean, self.var, self.count, batch_mean, batch_var, batch_count) + + +def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): + delta = batch_mean - mean + tot_count = count + batch_count + + new_mean = mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + return new_mean, new_var, new_count diff --git a/examples/utils/vtrace.py b/examples/utils/vtrace.py new file mode 100644 index 00000000..7b02c574 --- /dev/null +++ b/examples/utils/vtrace.py @@ -0,0 +1,138 @@ +# This file taken from +# https://github.com/deepmind/scalable_agent/blob/ +# cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py +# and modified. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions to compute V-trace off-policy actor critic targets. + +For details and theory see: + +"IMPALA: Scalable Distributed Deep-RL with +Importance Weighted Actor-Learner Architectures" +by Espeholt, Soyer, Munos et al. + +See https://arxiv.org/abs/1802.01561 for the full paper. +""" + +import collections + +import torch +import torch.nn.functional as F + +VTraceFromLogitsReturns = collections.namedtuple( + "VTraceFromLogitsReturns", + [ + "vs", + "pg_advantages", + "log_rhos", + "behavior_action_log_probs", + "target_action_log_probs", + ], +) + +VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") + + +def action_log_probs(policy_logits, actions): + return -F.nll_loss( + F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1), + torch.flatten(actions), + reduction="none", + ).view_as(actions) + + +def from_logits( + behavior_policy_logits, + target_policy_logits, + actions, + discounts, + rewards, + values, + bootstrap_value, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0, +): + """V-trace for softmax policies.""" + + target_action_log_probs = action_log_probs(target_policy_logits, actions) + behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions) + log_rhos = target_action_log_probs - behavior_action_log_probs + vtrace_returns = from_importance_weights( + log_rhos=log_rhos, + discounts=discounts, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold, + ) + return VTraceFromLogitsReturns( + log_rhos=log_rhos, + behavior_action_log_probs=behavior_action_log_probs, + target_action_log_probs=target_action_log_probs, + **vtrace_returns._asdict(), + ) + + +@torch.no_grad() +def from_importance_weights( + log_rhos, + discounts, + rewards, + values, + bootstrap_value, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0, +): + """V-trace from log importance weights.""" + with torch.no_grad(): + rhos = torch.exp(log_rhos) + if clip_rho_threshold is not None: + clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) + else: + clipped_rhos = rhos + + cs = torch.clamp(rhos, max=1.0) + # Append bootstrapped value to get [v1, ..., v_t+1] + values_t_plus_1 = torch.cat( + [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0 + ) + deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) + + acc = torch.zeros_like(bootstrap_value) + result = [] + for t in range(discounts.shape[0] - 1, -1, -1): + acc = deltas[t] + discounts[t] * cs[t] * acc + result.append(acc) + result.reverse() + vs_minus_v_xs = torch.stack(result) + + # Add V(x_s) to get v_s. + vs = torch.add(vs_minus_v_xs, values) + + # Advantage for policy gradient. + broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value + vs_t_plus_1 = torch.cat( + [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0 + ) + if clip_pg_rho_threshold is not None: + clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) + else: + clipped_pg_rhos = rhos + pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values) + + # Make sure no gradients backpropagated through the returned values. + return VTraceReturns(vs=vs, pg_advantages=pg_advantages) diff --git a/syllabus/core/__init__.py b/syllabus/core/__init__.py index 07322779..08f7a709 100644 --- a/syllabus/core/__init__.py +++ b/syllabus/core/__init__.py @@ -13,3 +13,4 @@ from .environment_sync_wrapper import MultiProcessingSyncWrapper, RaySyncWrapper # , PettingZooMultiProcessingSyncWrapper from .multivariate_curriculum_wrapper import MultitaskWrapper +from .stat_recorder import StatRecorder diff --git a/syllabus/core/curriculum_base.py b/syllabus/core/curriculum_base.py index 4ca9aeb0..3f31ab95 100644 --- a/syllabus/core/curriculum_base.py +++ b/syllabus/core/curriculum_base.py @@ -6,6 +6,7 @@ from gymnasium.spaces import Dict from syllabus.task_space import TaskSpace +from .stat_recorder import StatRecorder # TODO: Move non-generic logic to Uniform class. Allow subclasses to call super for generic error handling @@ -28,6 +29,7 @@ def __init__(self, task_space: TaskSpace, random_start_tasks: int = 0, task_name self.completed_tasks = 0 self.task_names = task_names self.n_updates = 0 + self.stat_recorder = StatRecorder(self.task_space) if self.num_tasks == 0: warnings.warn("Task space is empty. This will cause errors during sampling if no tasks are added.") @@ -74,6 +76,7 @@ def update_task_progress(self, task: typing.Any, progress: Tuple[float, bool], e :param task: Task for which progress is being updated. :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task. """ + self.completed_tasks += 1 def update_on_step(self, task: typing.Any, obs: typing.Any, rew: float, term: bool, trunc: bool, info: dict, env_id: int = None) -> None: @@ -96,19 +99,32 @@ def update_on_step_batch(self, step_results: List[typing.Tuple[Any, Any, int, in :param step_results: List of step results """ - tasks, obs, rews, terms, truncs, infos = tuple(step_results) + + #obs, rews, terms, truncs, infos = tuple(step_results) + obs = [r[0] for r in step_results] + rews = [r[1] for r in step_results] + terms = [r[2] for r in step_results] + truncs = [r[3] for r in step_results] + infos = [r[4] for r in step_results] + for i in range(len(obs)): self.update_on_step(tasks[i], obs[i], rews[i], terms[i], truncs[i], infos[i], env_id=env_id) - def update_on_episode(self, episode_return: float, episode_length: int, episode_task: Any, env_id: int = None) -> None: + def update_on_episode(self, episode_return: float, episode_length: int, episode_task, env_id=None) -> None: """Update the curriculum with episode results from the environment. :param episode_return: Episodic return :param trajectory: trajectory of (s, a, r, s, ...), defaults to None :raises NotImplementedError: """ - # TODO: Add update_on_episode option similar to update-on_step - pass + self.stat_recorder.record(episode_return, episode_length, episode_task, env_id) + #raise NotImplementedError("Not yet implemented.") + + def normalize(self, reward, task): + """ + Normalize reward by task. + """ + return self.stat_recorder.normalize(reward, task) def update_on_demand(self, metrics: Dict): """Update the curriculum with arbitrary inputs. diff --git a/syllabus/core/curriculum_sync_wrapper.py b/syllabus/core/curriculum_sync_wrapper.py index f9866438..dd86a858 100644 --- a/syllabus/core/curriculum_sync_wrapper.py +++ b/syllabus/core/curriculum_sync_wrapper.py @@ -51,10 +51,14 @@ def update_on_step(self, task, step, reward, term, trunc): def log_metrics(self, writer, step=None): self.curriculum.log_metrics(writer, step=step) + self.curriculum.stat_recorder.log_metrics(writer, step=step) def update_on_step_batch(self, step_results): self.curriculum.update_on_step_batch(step_results) + def update_on_episode(self, episode_return, episode_length, episode_task, env_id=None): + self.curriculum.update_on_episode(episode_return, episode_length, episode_task, env_id) + def update(self, metrics): self.curriculum.update(metrics) @@ -63,6 +67,9 @@ def update_batch(self, metrics): def add_task(self, task): self.curriculum.add_task(task) + + def normalize(self, rewards, task): + return self.curriculum.normalize(rewards, task) class MultiProcessingComponents: @@ -234,6 +241,9 @@ def _update_queues(self): else: time.sleep(0.01) + def update_on_episode(self, episode_return, episode_length, episode_task, env_id=None): + super().update_on_episode(episode_return, episode_length, episode_task, env_id) + def log_metrics(self, writer, step=None): super().log_metrics(writer, step=step) if self.get_components()._debug: @@ -246,6 +256,9 @@ def add_task(self, task): def get_components(self): return self._components + + def normalize(self, rewards, task): + return super().normalize(rewards, task) def remote_call(func): diff --git a/syllabus/core/environment_sync_wrapper.py b/syllabus/core/environment_sync_wrapper.py index 6edee7cc..e9a2b355 100644 --- a/syllabus/core/environment_sync_wrapper.py +++ b/syllabus/core/environment_sync_wrapper.py @@ -80,6 +80,8 @@ def reset(self, *args, **kwargs): added_tasks = message["added_tasks"] for add_task in added_tasks: self.env.add_task(add_task) + obs, info = self.env.reset(*args, new_task=next_task, **kwargs) + info["task"] = self.task_space.encode(self.get_task()) return self.env.reset(*args, new_task=next_task, **kwargs) def step(self, action): @@ -121,7 +123,9 @@ def step(self, action): "request_sample": True } self.components.put_update([task_update, episode_update]) - + + info["task"] = self.task_space.encode(self.get_task()) + return obs, rew, term, trunc, info def _package_step_updates(self): diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py new file mode 100644 index 00000000..d4c025b9 --- /dev/null +++ b/syllabus/core/stat_recorder.py @@ -0,0 +1,113 @@ +import os +import json +import warnings +import numpy as np +from syllabus.task_space import TaskSpace +from gymnasium.spaces import Discrete +from collections import deque, defaultdict + +class StatRecorder: + """ + Individual statistics tracking for each task. + """ + + def __init__(self, task_space: TaskSpace, calc_past_n=None): + """Initialize the StatRecorder""" + + self.task_space = task_space + self.calc_past_n = calc_past_n + + assert isinstance(self.task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." + assert isinstance(self.task_space.gym_space, Discrete), f"Only Discrete task spaces are supported. Got {type(task_space.gym_space)}" + + self.tasks = self.task_space.get_tasks() + self.num_tasks = self.task_space.num_tasks + + if self.calc_past_n is not None: + self.episode_returns = {task: deque(maxlen=calc_past_n) for task in self.tasks} + self.episode_lengths = {task: deque(maxlen=calc_past_n) for task in self.tasks} + self.env_ids = {task: deque(maxlen=calc_past_n) for task in self.tasks} + else: + self.num_past_episodes = {task: 0 for task in self.tasks} + + self.stats = {task: defaultdict(float) for task in self.tasks} + + def record(self, episode_return: float, episode_length: int, episode_task, env_id=None): + """ + Record the length and return of an episode for a given task. + + :param episode_length: Length of the episode, i.e. the total number of steps taken during the episode + :param episodic_return: Total return for the episode + :param episode_task: Identifier for the task + """ + + if episode_task in self.tasks: + if self.calc_past_n is not None: + self.episode_returns[episode_task].append(episode_return) + self.episode_lengths[episode_task].append(episode_length) + self.env_ids[episode_task].append(env_id) + + self.stats[episode_task]['mean_r'] = np.mean(self.episode_returns[episode_task]) + self.stats[episode_task]['var_r'] = np.var(self.episode_returns[episode_task]) + self.stats[episode_task]['mean_l'] = np.mean(self.episode_lengths[episode_task]) + self.stats[episode_task]['var_l'] = np.var(self.episode_lengths[episode_task]) + else: + n_past = self.num_past_episodes[episode_task] + self.num_past_episodes[episode_task] += 1 + + self.stats[episode_task]['mean_r'] = (self.stats[episode_task]['mean_r'] * n_past + episode_return) / (n_past + 1) + self.stats[episode_task]['mean_r_squared'] = (self.stats[episode_task]['mean_r_squared'] * n_past + episode_return ** 2) / (n_past + 1) + self.stats[episode_task]['var_r'] = self.stats[episode_task]['mean_r_squared'] - self.stats[episode_task]['mean_r'] ** 2 + + self.stats[episode_task]['mean_l'] = (self.stats[episode_task]['mean_l'] * n_past + episode_length) / (n_past + 1) + self.stats[episode_task]['mean_l_squared'] = (self.stats[episode_task]['mean_l_squared'] * n_past + episode_length ** 2) / (n_past + 1) + self.stats[episode_task]['var_l'] = self.stats[episode_task]['mean_l_squared'] - self.stats[episode_task]['mean_l'] ** 2 + else: + raise ValueError("Unknown task") + + def log_metrics(self, writer, step=None, log_full_dist=False): + """Log the statistics of the first 5 tasks to the provided tensorboard writer. + + :param writer: Tensorboard summary writer. + """ + try: + import wandb + tasks_to_log = self.tasks + if len(self.tasks) > 5 and not log_full_dist: + warnings.warn("Only logging stats for 5 tasks.") + tasks_to_log = self.tasks[:5] + for idx in tasks_to_log: + if self.stats[idx]: + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", self.stats[idx]['mean_r'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", self.stats[idx]['var_r'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", self.stats[idx]['mean_l'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", self.stats[idx]['var_l'], step) + except ImportError: + warnings.warn("Wandb is not installed. Skipping logging.") + except wandb.errors.Error: + # No need to crash over logging :) + warnings.warn("Failed to log curriculum stats to wandb.") + + def normalize(self, reward, task): + """ + Normalize reward by task. + """ + task_stats = self.stats[task] + reward_mean = task_stats['mean_r'] + reward_std = np.sqrt(task_stats['var_r']) + normalized_reward = deque(maxlen=reward.maxlen) + for r in reward: + normalized_reward.append((r - reward_mean) / max(0.01, reward_std)) + return normalized_reward + + def save_statistics(self, output_path): + """ + Write task-specific statistics to file. + """ + def convert_numpy(obj): + if isinstance(obj, np.generic): + return obj.item() # Use .item() to convert numpy types to native Python types + raise TypeError + stats = json.dumps(self.stats, default=convert_numpy) + with open(os.path.join(output_path, 'task_specific_stats.json'), "w") as file: + file.write(stats) \ No newline at end of file diff --git a/syllabus/curricula/plr/plr_wrapper.py b/syllabus/curricula/plr/plr_wrapper.py index 9c808ddc..f4dae1f4 100644 --- a/syllabus/curricula/plr/plr_wrapper.py +++ b/syllabus/curricula/plr/plr_wrapper.py @@ -264,6 +264,25 @@ def update_on_step_batch( task=tasks, ) + #def update_on_episode(self, episode_return: float, episode_task, env_id: int = None) -> None: + #""" + ##Update the curriculum with episode results from the environment. + #"" + #raise NotImplementedError( + #"PrioritizedLevelReplay does not support the episode updates. Use on_demand from the learner process." + #) + + + def update_task_progress(self, task: Any, success_prob: float, env_id: int = None) -> None: + """ + Update the curriculum with a task and its success probability upon + success or failure. + """ + assert env_id is not None, "env_id must be provided for PLR updates." + self._rollouts.insert_at_index( + env_id, + task=task, + ) # Update task sampler if env_id in self._rollouts.ready_buffers: self._update_sampler(env_id) diff --git a/syllabus/examples/task_wrappers/procgen_task_wrapper.py b/syllabus/examples/task_wrappers/procgen_task_wrapper.py index 2296fd58..497c9eca 100644 --- a/syllabus/examples/task_wrappers/procgen_task_wrapper.py +++ b/syllabus/examples/task_wrappers/procgen_task_wrapper.py @@ -40,6 +40,7 @@ def __init__(self, env: gym.Env, env_id, seed=0): def seed(self, seed): self.env.gym_env.unwrapped._venv.seed(int(seed), 0) + #pass def reset(self, new_task=None, **kwargs): """ diff --git a/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py b/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py index b848d693..704dfcae 100644 --- a/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py +++ b/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py @@ -150,46 +150,6 @@ def wrap_vecenv(vecenv): return vecenv -def full_level_replay_evaluate( - env_name, - policy, - num_episodes, - device, - num_levels=1 # Not used -): - policy.eval() - - eval_envs = ProcgenEnv( - num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False - ) - eval_envs = VecExtractDictObs(eval_envs, "rgb") - eval_envs = wrap_vecenv(eval_envs) - - # Seed environments - seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] - for i, seed in enumerate(seeds): - eval_envs.seed(seed, i) - - eval_obs, _ = eval_envs.reset() - eval_episode_rewards = [-1] * num_episodes - - while -1 in eval_episode_rewards: - with torch.no_grad(): - eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) - - eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) - for i, info in enumerate(infos): - if 'episode' in info.keys() and eval_episode_rewards[i] == -1: - eval_episode_rewards[i] = info['episode']['r'] - - mean_returns = np.mean(eval_episode_rewards) - stddev_returns = np.std(eval_episode_rewards) - env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] - normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) - policy.train() - return mean_returns, stddev_returns, normalized_mean_returns - - def level_replay_evaluate( env_name, policy, @@ -493,15 +453,9 @@ def level_replay_evaluate( mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=0 ) - full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=0 - ) mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=200 ) - full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=200 - ) # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) @@ -519,16 +473,10 @@ def level_replay_evaluate( writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step) writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step) writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step) - writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step) writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step) writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step) writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step) - writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step) writer.add_scalar("curriculum/completed_episodes", completed_episodes, step) diff --git a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py index dabcd500..536ecd87 100644 --- a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py +++ b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py @@ -150,46 +150,6 @@ def wrap_vecenv(vecenv): return vecenv -def full_level_replay_evaluate( - env_name, - policy, - num_episodes, - device, - num_levels=1 # Not used -): - policy.eval() - - eval_envs = ProcgenEnv( - num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False - ) - eval_envs = VecExtractDictObs(eval_envs, "rgb") - eval_envs = wrap_vecenv(eval_envs) - - # Seed environments - seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] - for i, seed in enumerate(seeds): - eval_envs.seed(seed, i) - - eval_obs, _ = eval_envs.reset() - eval_episode_rewards = [-1] * num_episodes - - while -1 in eval_episode_rewards: - with torch.no_grad(): - eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) - - eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) - for i, info in enumerate(infos): - if 'episode' in info.keys() and eval_episode_rewards[i] == -1: - eval_episode_rewards[i] = info['episode']['r'] - - mean_returns = np.mean(eval_episode_rewards) - stddev_returns = np.std(eval_episode_rewards) - env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] - normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) - policy.train() - return mean_returns, stddev_returns, normalized_mean_returns - - def level_replay_evaluate( env_name, policy, @@ -251,7 +211,7 @@ def get_value(obs): ) # wandb.run.log_code("./syllabus/examples") - writer = SummaryWriter(os.path.join(args.logging_dir, "./runs/{run_name}")) + writer = SummaryWriter(os.path.join(args.logging_dir, f"./runs/{run_name}")) writer.add_text( "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), @@ -296,11 +256,11 @@ def get_value(obs): print("Using sequential curriculum.") curricula = [] stopping = [] - for i in range(199): - curricula.append(i + 1) - stopping.append("steps>=50000") - curricula.append(list(range(i + 1))) - stopping.append("steps>=50000") + for i in range(0, 199, 10): + curricula.append(list(range(i, i+10))) + stopping.append("steps>=500000") + curricula.append(list(range(i + 10))) + stopping.append("steps>=500000") curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space) else: raise ValueError(f"Unknown curriculum method {args.curriculum_method}") @@ -309,7 +269,8 @@ def get_value(obs): # env setup print("Creating env") - envs = gym.vector.AsyncVectorEnv( + envs = gym.vector.SyncVectorEnv( + #envs = gym.vector.AsyncVectorEnv( [ make_env( args.env_id, @@ -322,6 +283,24 @@ def get_value(obs): ) envs = wrap_vecenv(envs) + test_eval_envs = gym.vector.SyncVectorEnv( + #test_eval_envs = gym.vector.AsyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, num_levels=0) + for i in range(args.num_eval_episodes) + ] + ) + test_eval_envs = wrap_vecenv(test_eval_envs) + + train_eval_envs = gym.vector.SyncVectorEnv( + #train_eval_envs = gym.vector.AsyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, num_levels=200) + for i in range(args.num_eval_episodes) + ] + ) + train_eval_envs = wrap_vecenv(train_eval_envs) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" print("Creating agent") agent = ProcgenAgent( @@ -384,6 +363,12 @@ def get_value(obs): writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) if curriculum is not None: curriculum.log_metrics(writer, global_step) + + # track return for individual tasks + idx = info.index(item) + multiprocessing_sync_wrapper_envs = envs.venv.venv.envs # extract envs of class MultiProcessingSyncWrapper from envs of class VecNormalize, which has access to task id + episode_task = multiprocessing_sync_wrapper_envs[idx]._latest_task + curriculum.update_on_episode(item["episode"]["r"], item["episode"]["l"], episode_task, args.env_id) break # bootstrap value if not done @@ -485,15 +470,9 @@ def get_value(obs): mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=0 ) - full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=0 - ) mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=200 ) - full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=200 - ) # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) @@ -511,16 +490,10 @@ def get_value(obs): writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step) writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step) writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step) - writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step) writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step) writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step) writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step) - writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step) writer.add_scalar("curriculum/completed_episodes", completed_episodes, step) diff --git a/task_space/__init__.py b/task_space/__init__.py new file mode 100644 index 00000000..1561be1c --- /dev/null +++ b/task_space/__init__.py @@ -0,0 +1 @@ +from .task_space import TaskSpace diff --git a/task_space/task_space.py b/task_space/task_space.py new file mode 100644 index 00000000..1ef674be --- /dev/null +++ b/task_space/task_space.py @@ -0,0 +1,234 @@ +import itertools +from typing import Any, List, Union + +import numpy as np +from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple + + +class TaskSpace(): + def __init__(self, gym_space: Union[Space, int], tasks=None): + + if not isinstance(gym_space, Space): + gym_space = self._create_gym_space(gym_space) + + self.gym_space = gym_space + + # Autogenerate task names + if tasks is None: + tasks = self._generate_task_names(gym_space) + + self._tasks = set(tasks) if tasks is not None else None + self._encoder, self._decoder = self._make_task_encoder(gym_space, tasks) + + def _create_gym_space(self, gym_space): + if isinstance(gym_space, int): + # Syntactic sugar for discrete space + gym_space = Discrete(gym_space) + elif isinstance(gym_space, tuple): + # Syntactic sugar for discrete space + gym_space = MultiDiscrete(gym_space) + elif isinstance(gym_space, list): + # Syntactic sugar for tuple space + spaces = [] + for i, value in enumerate(gym_space): + spaces[i] = self._create_gym_space(value) + gym_space = Tuple(spaces) + elif isinstance(gym_space, dict): + # Syntactic sugar for dict space + spaces = {} + for key, value in gym_space.items(): + spaces[key] = self._create_gym_space(value) + gym_space = Dict(spaces) + return gym_space + + def _generate_task_names(self, gym_space): + if isinstance(gym_space, Discrete): + tasks = tuple(range(gym_space.n)) + elif isinstance(gym_space, MultiDiscrete): + tasks = [tuple(range(dim)) for dim in gym_space.nvec] + elif isinstance(gym_space, Tuple): + tasks = [self._generate_task_names(value) for value in gym_space.spaces] + elif isinstance(gym_space, Dict): + tasks = {key: tuple(self._generate_task_names(value)) for key, value in gym_space.spaces.items()} + else: + tasks = None + return tasks + + def _make_task_encoder(self, space, tasks): + if isinstance(space, Discrete): + assert space.n == len(tasks), f"Number of tasks ({space.n}) must match number of discrete options ({len(tasks)})" + self._encode_map = {task: i for i, task in enumerate(tasks)} + self._decode_map = {i: task for i, task in enumerate(tasks)} + encoder = lambda task: self._encode_map[task] if task in self._encode_map else None + decoder = lambda task: self._decode_map[task] if task in self._decode_map else None + + elif isinstance(space, Box): + encoder = lambda task: task if space.contains(np.asarray(task, dtype=space.dtype)) else None + decoder = lambda task: task if space.contains(np.asarray(task, dtype=space.dtype)) else None + elif isinstance(space, Tuple): + + assert len(space.spaces) == len(tasks), f"Number of task ({len(space.spaces)})must match options in Tuple ({len(tasks)})" + results = [list(self._make_task_encoder(s, t)) for (s, t) in zip(space.spaces, tasks)] + encoders = [r[0] for r in results] + decoders = [r[1] for r in results] + encoder = lambda task: [e(t) for e, t in zip(encoders, task)] + decoder = lambda task: [d(t) for d, t in zip(decoders, task)] + + elif isinstance(space, MultiDiscrete): + assert len(space.nvec) == len(tasks), f"Number of steps in a tasks ({len(space.nvec)}) must match number of discrete options ({len(tasks)})" + + combinations = [p for p in itertools.product(*tasks)] + encode_map = {task: i for i, task in enumerate(combinations)} + decode_map = {i: task for i, task in enumerate(combinations)} + + encoder = lambda task: encode_map[task] if task in encode_map else None + decoder = lambda task: decode_map[task] if task in decode_map else None + + elif isinstance(space, Dict): + + def helper(task, spaces, tasks, action="encode"): + # Iteratively encodes or decodes each space in the dictionary + output = {} + if (isinstance(spaces, dict) or isinstance(spaces, Dict)): + for key, value in spaces.items(): + if (isinstance(value, dict) or isinstance(value, Dict)): + temp = helper(task[key], value, tasks[key], action) + output.update({key: temp}) + else: + encoder, decoder = self._make_task_encoder(value, tasks[key]) + output[key] = encoder(task[key]) if action == "encode" else decoder(task[key]) + return output + + encoder = lambda task: helper(task, space.spaces, tasks, "encode") + decoder = lambda task: helper(task, space.spaces, tasks, "decode") + else: + encoder = lambda task: task + decoder = lambda task: task + return encoder, decoder + + def decode(self, encoding): + """Convert the efficient task encoding to a task that can be used by the environment.""" + return self._decoder(encoding) + + def encode(self, task): + """Convert the task to an efficient encoding to speed up multiprocessing.""" + return self._encoder(task) + + def add_task(self, task): + """Add a task to the task space. Only implemented for discrete spaces.""" + if task not in self._tasks: + self._tasks.add(task) + # TODO: Increment task space size + self.gym_space = self.increase_space() + # TODO: Optimize adding tasks + self._encoder, self._decoder = self._make_task_encoder(self.gym_space, self._tasks) + + def _sum_axes(list_or_size: Union[list, int]): + if isinstance(list_or_size, int) or isinstance(list_or_size, np.int64): + return list_or_size + elif isinstance(list_or_size, list) or isinstance(list_or_size, np.ndarray): + return np.prod([TaskSpace._sum_axes(x) for x in list_or_size]) + else: + raise NotImplementedError(f"{type(list_or_size)}") + + def _enumerate_axes(self, list_or_size: Union[np.ndarray, int]): + if isinstance(list_or_size, int) or isinstance(list_or_size, np.int64): + return tuple(range(list_or_size)) + elif isinstance(list_or_size, list) or isinstance(list_or_size, np.ndarray): + return tuple(itertools.product(*[self._enumerate_axes(x) for x in list_or_size])) + else: + raise NotImplementedError(f"{type(list_or_size)}") + + def seed(self, seed): + self.gym_space.seed(seed) + + @property + def tasks(self) -> List[Any]: + # TODO: Can I just use _tasks? + return self._tasks + + def get_tasks(self, gym_space: Space = None, sample_interval: float = None) -> List[tuple]: + """ + Return the full list of discrete tasks in the task_space. + Return a sample of the tasks for continuous spaces if sample_interval is specified. + Can be overridden to exclude invalid tasks within the space. + """ + if gym_space is None: + gym_space = self.gym_space + + if isinstance(gym_space, Discrete): + return list(range(gym_space.n)) + elif isinstance(gym_space, Box): + raise NotImplementedError + elif isinstance(gym_space, Tuple): + return list(itertools.product([self.get_tasks(task_space=s) for s in gym_space.spaces])) + elif isinstance(gym_space, Dict): + return itertools.product([self.get_tasks(task_space=s) for s in gym_space.spaces.values()]) + elif isinstance(gym_space, MultiBinary): + return list(self._enumerate_axes(gym_space.nvec)) + elif isinstance(gym_space, MultiDiscrete): + return list(self._enumerate_axes(gym_space.nvec)) + elif gym_space is None: + return [] + else: + raise NotImplementedError + + @property + def num_tasks(self) -> int: + # TODO: Cache results + return self.count_tasks() + + def count_tasks(self, gym_space: Space = None) -> int: + """ + Return the number of discrete tasks in the task_space. + Returns None for continuous spaces. + Graph space not implemented. + """ + # TODO: Test these implementations + if gym_space is None: + gym_space = self.gym_space + + if isinstance(gym_space, Discrete): + return gym_space.n + elif isinstance(gym_space, Box): + return None + elif isinstance(gym_space, Tuple): + return sum([self.count_tasks(gym_space=s) for s in gym_space.spaces]) + elif isinstance(gym_space, Dict): + return sum([self.count_tasks(gym_space=s) for s in gym_space.spaces.values()]) + elif isinstance(gym_space, MultiBinary): + return TaskSpace._sum_axes(gym_space.nvec) + elif isinstance(gym_space, MultiDiscrete): + return TaskSpace._sum_axes(gym_space.nvec) + elif gym_space is None: + return 0 + else: + raise NotImplementedError(f"Unsupported task space type: {type(gym_space)}") + + def task_name(self, task): + return repr(self.decode(task)) + + def contains(self, task): + return task in self._tasks or self.decode(task) in self._tasks + + def increase_space(self, amount: Union[int, float] = 1): + if isinstance(self.gym_space, Discrete): + assert isinstance(amount, int), f"Discrete task space can only be increased by integer amount. Got {amount} instead." + return Discrete(self.gym_space.n + amount) + + def sample(self): + assert isinstance(self.gym_space, Discrete) or isinstance(self.gym_space, Box) or isinstance(self.gym_space, Dict) or isinstance(self.gym_space, Tuple) + return self.decode(self.gym_space.sample()) + + def list_tasks(self): + return list(self._tasks) + + def box_contains(self, x) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + if not isinstance(x, np.ndarray): + try: + x = np.asarray(x, dtype=self.gym_space.dtype) + except (ValueError, TypeError): + return False + + return not bool(x.shape == self.gym_space.shape and np.any((x < self.gym_space.low) | (x > self.gym_space.high))) diff --git a/task_space/test_task_space.py b/task_space/test_task_space.py new file mode 100644 index 00000000..109d0a7e --- /dev/null +++ b/task_space/test_task_space.py @@ -0,0 +1,182 @@ +import gymnasium as gym +from syllabus.task_space import TaskSpace + +if __name__ == "__main__": + # Discrete Tests + task_space = TaskSpace(gym.spaces.Discrete(3), ["a", "b", "c"]) + + assert task_space.encode("a") == 0, f"Expected 0, got {task_space.encode('a')}" + assert task_space.encode("b") == 1, f"Expected 1, got {task_space.encode('b')}" + assert task_space.encode("c") == 2, f"Expected 2, got {task_space.encode('c')}" + assert task_space.encode("d") is None, f"Expected None, got {task_space.encode('d')}" + + assert task_space.decode(0) == "a", f"Expected a, got {task_space.decode(0)}" + assert task_space.decode(1) == "b", f"Expected b, got {task_space.decode(1)}" + assert task_space.decode(2) == "c", f"Expected c, got {task_space.decode(2)}" + assert task_space.decode(3) is None, f"Expected None, got {task_space.decode(3)}" + print("Discrete tests passed!") + + # MultiDiscrete Tests + task_space = TaskSpace(gym.spaces.MultiDiscrete([3, 2]), [("a", "b", "c"), (1, 0)]) + + assert task_space.encode(('a', 1)) == 0, f"Expected 0, got {task_space.encode(('a', 1))}" + assert task_space.encode(('b', 0)) == 3, f"Expected 3, got {task_space.encode(('b', 0))}" + assert task_space.encode(('c', 1)) == 4, f"Expected 4, got {task_space.encode(('c', 1))}" + + assert task_space.decode(3) == ('b', 0), f"Expected ('b', 0), got {task_space.decode(3)}" + assert task_space.decode(5) == ('c', 0), f"Expected ('c', 0), got {task_space.decode(5)}" + print("MultiDiscrete tests passed!") + + # Box Tests + task_space = TaskSpace(gym.spaces.Box(low=0, high=1, shape=(2,)), [(0, 0), (0, 1), (1, 0), (1, 1)]) + + assert task_space.encode([0.0, 0.0]) == [0.0, 0.0], f"Expected [0.0, 0.0], got {task_space.encode([0.0, 0.0])}" + assert task_space.encode([0.0, 0.1]) == [0.0, 0.1], f"Expected [0.0, 0.1], got {task_space.encode([0.0, 0.1])}" + assert task_space.encode([0.1, 0.1]) == [0.1, 0.1], f"Expected [0.1, 0.1], got {task_space.encode([0.1, 0.1])}" + assert task_space.encode([1.0, 0.1]) == [1.0, 0.1], f"Expected [1.0, 0.1], got {task_space.encode([1.0, 0.1])}" + assert task_space.encode([1.0, 1.0]) == [1.0, 1.0], f"Expected [1.0, 1.0], got {task_space.encode([1.0, 1.0])}" + assert task_space.encode([1.2, 1.0]) is None, f"Expected None, got {task_space.encode([1.2, 1.0])}" + assert task_space.encode([1.0, 1.2]) is None, f"Expected None, got {task_space.encode([1.2, 1.0])}" + assert task_space.encode([-0.1, 1.0]) is None, f"Expected None, got {task_space.encode([1.2, 1.0])}" + + assert task_space.decode([1.0, 1.0]) == [1.0, 1.0], f"Expected [1.0, 1.0], got {task_space.decode([1.0, 1.0])}" + assert task_space.decode([0.1, 0.1]) == [0.1, 0.1], f"Expected [0.1, 0.1], got {task_space.decode([0.1, 0.1])}" + assert task_space.decode([-0.1, 1.0]) is None, f"Expected None, got {task_space.decode([1.2, 1.0])}" + print("Box tests passed!") + + # Tuple Tests + task_spaces = (gym.spaces.MultiDiscrete([3, 2]), gym.spaces.Discrete(3)) + task_names = ((("a", "b", "c"), (1, 0)), ("X", "Y", "Z")) + task_space = TaskSpace(gym.spaces.Tuple(task_spaces), task_names) + + assert task_space.encode((('a', 0), 'Y')) == [1, 1], f"Expected 0, got {task_space.encode((('a', 1),'Y'))}" + assert task_space.decode([0, 1]) == [('a', 1), 'Y'], f"Expected 0, got {task_space.decode([0, 1])}" + print("Tuple tests passed!") + + # Dictionary Tests + task_spaces = gym.spaces.Dict({ + "ext_controller": gym.spaces.MultiDiscrete([5, 2, 2]), + "inner_state": gym.spaces.Dict( + { + "charge": gym.spaces.Discrete(10), + "system_checks": gym.spaces.Tuple((gym.spaces.MultiDiscrete([3, 2]), gym.spaces.Discrete(3))), + "job_status": gym.spaces.Dict( + { + "task": gym.spaces.Discrete(5), + "progress": gym.spaces.Box(low=0, high=1, shape=(2,)), + } + ), + } + ), + }) + task_names = { + "ext_controller": [("a", "b", "c", "d", "e"), (1, 0), ("X", "Y")], + "inner_state": { + "charge": [0, 1, 13, 3, 94, 35, 6, 37, 8, 9], + "system_checks": ((("a", "b", "c"), (1, 0)), ("X", "Y", "Z")), + "job_status": { + "task": ["A", "B", "C", "D", "E"], + "progress": [(0, 0), (0, 1), (1, 0), (1, 1)], + } + } + } + task_space = TaskSpace(task_spaces, task_names) + + test_val = { + "ext_controller": ('b', 1, 'X'), + 'inner_state': { + 'charge': 1, + 'system_checks': [('a', 0), 'Y'], + 'job_status': {'task': 'C', 'progress': [0.0, 0.0]} + } + } + decode_val = { + "ext_controller": 4, + "inner_state": { + "charge": 1, + "system_checks": [1, 1], + "job_status": {"progress": [0.0, 0.0], "task": 2}, + }, + } + + assert task_space.encode(test_val) == decode_val, f"Expected {decode_val}, \n but got {task_space.encode(test_val)}" + assert task_space.decode(decode_val) == test_val, f"Expected {test_val}, \n but got {task_space.decode(decode_val)}" + + test_val_2 = { + "ext_controller": ("e", 1, "Y"), + "inner_state": { + "charge": 37, + "system_checks": [("b", 0), "Z"], + "job_status": {"progress": [0.0, 0.1], "task": "D"}, + }, + } + decode_val_2 = { + "ext_controller": 17, + "inner_state": { + "charge": 7, + "system_checks": [3, 2], + "job_status": {"progress": [0.0, 0.1], "task": 3}, + }, + } + + assert task_space.encode(test_val_2) == decode_val_2, f"Expected {decode_val_2}, \n but got {task_space.encode(test_val_2)}" + assert task_space.decode(decode_val_2) == test_val_2, f"Expected {test_val_2}, \n but got {task_space.decode(decode_val_2)}" + + test_val_3 = { + "ext_controller": ("e", 1, "X"), + "inner_state": { + "charge": 8, + "system_checks": [("c", 0), "X"], + "job_status": {"progress": [0.5, 0.1], "task": "E"}, + }, + } + decode_val_3 = { + "ext_controller": 16, + "inner_state": { + "charge": 8, + "system_checks": [5, 0], + "job_status": {"progress": [0.5, 0.1], "task": 4}, + }, + } + + assert task_space.encode(test_val_3) == decode_val_3, f"Expected {decode_val_3}, \n but got {task_space.encode(test_val_3)}" + assert task_space.decode(decode_val_3) == test_val_3, f"Expected {test_val_3}, \n but got {task_space.decode(decode_val_3)}" + + print("Dictionary tests passed!") + + # Test syntactic sugar + task_space = TaskSpace(3) + assert task_space.encode(0) == 0, f"Expected 0, got {task_space.encode(0)}" + assert task_space.encode(1) == 1, f"Expected 1, got {task_space.encode(1)}" + assert task_space.encode(2) == 2, f"Expected 2, got {task_space.encode(2)}" + assert task_space.encode(3) is None, f"Expected None, got {task_space.encode(3)}" + + task_space = TaskSpace((2, 4)) + assert task_space.encode((0, 0)) == 0, f"Expected 0, got {task_space.encode((0, 0))}" + assert task_space.encode((0, 1)) == 1, f"Expected 1, got {task_space.encode((0, 1))}" + assert task_space.encode((1, 0)) == 4, f"Expected 2, got {task_space.encode((1, 0))}" + assert task_space.encode((3, 3)) is None, f"Expected None, got {task_space.encode((3, 3))}" + + task_space = TaskSpace((2, 4)) + assert task_space.encode((0, 0)) == 0, f"Expected 0, got {task_space.encode((0, 0))}" + assert task_space.encode((0, 1)) == 1, f"Expected 1, got {task_space.encode((0, 1))}" + assert task_space.encode((1, 0)) == 4, f"Expected 2, got {task_space.encode((1, 0))}" + assert task_space.encode((3, 3)) is None, f"Expected None, got {task_space.encode((3, 3))}" + + task_space = TaskSpace({"map": 5, "level": (4, 10), "difficulty": 3}) + + encoding = task_space.encode({"map": 0, "level": (0, 0), "difficulty": 0}) + expected = {"map": 0, "level": 0, "difficulty": 0} + + encoding = task_space.encode({"map": 4, "level": (3, 9), "difficulty": 2}) + expected = {"map": 4, "level": 39, "difficulty": 2} + assert encoding == expected, f"Expected {expected}, got {encoding}" + + encoding = task_space.encode({"map": 2, "level": (2, 0), "difficulty": 1}) + expected = {"map": 2, "level": 20, "difficulty": 1} + assert encoding == expected, f"Expected {expected}, got {encoding}" + + encoding = task_space.encode({"map": 5, "level": (2, 11), "difficulty": -1}) + expected = {"map": None, "level": None, "difficulty": None} + assert encoding == expected, f"Expected {expected}, got {encoding}" + print("All tests passed!") diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..f600ce6b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +from .sync_test_curriculum import SyncTestCurriculum +from .sync_test_env import SyncTestEnv +from .utils import * +from .determinism import test_determinism diff --git a/tests/determinism.py b/tests/determinism.py new file mode 100644 index 00000000..af4e8e86 --- /dev/null +++ b/tests/determinism.py @@ -0,0 +1,154 @@ +import numpy as np +from syllabus.tests import evaluate_random_policy + +N_EPISODES = 10 + + +def print_if_verbose(verbose, *args, **kwargs): + if verbose > 0: + print(*args, **kwargs) + + +def compare_obs(obs1, obs2): + if not isinstance(obs1, type(obs2)): + print(f"Type mismatch: {type(obs1)} != {type(obs2)}") + return False + + if isinstance(obs1, dict): + for key in obs1.keys(): + if key not in obs2: + print(f"Key {key} not in obs2") + return False + if not compare_obs(obs1[key], obs2[key]): + return False + return True + if isinstance(obs1, (list, tuple, np.ndarray)): + if len(obs1) != len(obs2): + print(f"Length mismatch: {len(obs1)} != {len(obs2)}") + return False + for i in range(len(obs1)): + if not compare_obs(obs1[i], obs2[i]): + return False + return True + else: + return obs1 == obs2 + + +def compare_episodes(make_env, task1, task2, verbose=0): + term1 = trunc1 = term2 = trunc2 = False + step = 0 + num_obs_failed = 0 + num_rews_failed = 0 + + # Set tasks + env1 = make_env() + env1.reset(new_task=task1) + env2 = make_env() + env2.reset(new_task=task2) + + # Seed spaces + env1.action_space.seed(0) + env1.observation_space.seed(0) + env1.task_space.seed(0) + env2.action_space.seed(0) + env2.observation_space.seed(0) + env2.task_space.seed(0) + + while not (term1 or trunc1 or term2 or trunc2): + action1 = env1.action_space.sample() + action2 = env2.action_space.sample() + + # Check actions + if action1 != action2: + print_if_verbose(verbose, f"Step {step}: Actions are not the same: {action1} != {action2}. Stopping test.") + return False + + obs1, rew1, term1, trunc1, info1 = env1.step(action1) + obs2, rew2, term2, trunc2, info2 = env2.step(action2) + + # Check observations + if not compare_obs(obs1, obs2): + if num_obs_failed == 0: + print_if_verbose(verbose, f"Step {step}: Obs are not the same: {obs1} != {obs2}. This message will not print for future steps.") + num_obs_failed += 1 + + # Check rewards + if rew1 != rew2: + if num_rews_failed == 0: + print_if_verbose(verbose, f"Step {step}: Rewards are not the same: {rew1} != {rew2}. This message will not print for future steps.") + num_rews_failed += 1 + + # Check terms + if term1 != term2: + print_if_verbose(verbose, f"Step {step}: Terms are not the same: {term1} != {term2}. Stopping test.") + return False + + # Check truncs + if trunc1 != trunc2: + print_if_verbose(verbose, f"Step {step}: Truncs are not the same: {trunc1} != {trunc2}. Stopping test.") + return False + + step += 1 + return num_obs_failed == 0 and num_rews_failed == 0 + + +def test_determinism(make_env, verbose=0): + # TODO: Use the task space to sampele seeds/tasks + test_env = make_env() + assert hasattr(test_env, "task_space"), "Environment does not have a task space. make_env must return a TaskEnv or use a TaskWrapper." + task_space = test_env.task_space + + print_if_verbose(verbose, "Runnning determinism tests...") + + # Test full episode returns + print_if_verbose(verbose, "\nTesting average episodic returns...") + seeds = [task_space.sample() for _ in range(N_EPISODES)] + return1, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=seeds) + return2, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=seeds) + full_return_test = return1 == return2 + if full_return_test: + print_if_verbose(verbose, "PASSED: Random policy returns are deterministic!") + else: + print_if_verbose(verbose, f"FAILED: Random policy returns are not deterministic! {return1} != {return2}") + + # Test individual episode returns + print_if_verbose(verbose, "\nTesting individual episode rewards...") + avg_returns, returns = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=[task_space.sample()] * N_EPISODES) + return_test = all(returns == avg_returns) + if return_test: + print_if_verbose(verbose, "PASSED: Episodes returns are deterministic!") + else: + print_if_verbose(verbose, f"FAILED: Episodes returns are not deterministic! {avg_returns} != {returns}") + + print_if_verbose(verbose, "\nTesting different seeds...") + task1 = task2 = task_space.sample() + while task1 == task2: + task2 = task_space.sample() + return1, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=[task1] * N_EPISODES) + return2, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=[task2] * N_EPISODES) + + test3 = return1 != return2 + if test3: + print_if_verbose(verbose, "PASSED: Random policy returns with different seeds are different.") + else: + print_if_verbose(verbose, f"FAILED: Random policy returns with different seeds are the same. {return1} == {return2}") + + print_if_verbose(verbose, "\nTesting actions, rewards, and observations seeds...") + task1 = task2 = task_space.sample() + step_tests_same = compare_episodes(make_env, task1, task2, verbose=verbose) + while task1 == task2: + task2 = task_space.sample() + step_tests_different = compare_episodes(make_env, task1, 2, verbose=0) + + if step_tests_same and not step_tests_different: + print_if_verbose(verbose, "PASSED: Environment returns on individual steps are deterministic with respect to seed.") + elif step_tests_different: + print_if_verbose(verbose, "FAILED: Environment returns on individual steps are deterministic even with different seeds.") + else: + print_if_verbose(verbose, "FAILED: Environment returns on individual steps are not deterministic with the same seed.") + + return { + "avg_episodic_returns": full_return_test, + "episodic_returns": return_test, + "step_values": step_tests_same + } diff --git a/tests/sync_test_curriculum.py b/tests/sync_test_curriculum.py new file mode 100644 index 00000000..45399f29 --- /dev/null +++ b/tests/sync_test_curriculum.py @@ -0,0 +1,51 @@ +import typing +from typing import Any, List, Union + +from syllabus.curricula import SequentialCurriculum + + +class SyncTestCurriculum(SequentialCurriculum): + """ + Base class and API for defining curricula to interface with Gym environments. + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = True + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, num_envs, num_episodes, *curriculum_args, **curriculum_kwargs): + # Create a manual curriculum with a new task per episode, repeated across all envs + task_list = [f"task {i+1}" for i in range(num_episodes)] + stopping = [f"tasks>={num_envs}"] * (num_episodes - 1) + super().__init__(task_list, stopping, *curriculum_args, **curriculum_kwargs) + self.num_envs = num_envs + self.num_episodes = num_episodes + self.task_counts = {self.task_space.encode(task): 0 for task in task_list} + self.task_counts[0] = 0 # Error task + self.total_reward = 0 + self.total_dones = 0 + + def update_on_episode(self, episode_return, episode_len, episode_task, env_id: int = None) -> None: + super().update_on_episode(episode_return, episode_len, episode_task, env_id) + self.total_reward += episode_return + self.total_dones += 1 + self.task_counts[episode_task] += 1 + + def get_stats(self): + return { + "total_reward": self.total_reward, + "total_dones": self.total_dones, + "task_counts": self.task_counts + } + + def sample(self, k: int = 1) -> Union[List, Any]: + remaining_tasks = (self.num_episodes * self.num_envs) - self.total_tasks + if remaining_tasks < k: + tasks = super().sample(k=remaining_tasks) + [0] * (k - remaining_tasks) + else: + tasks = [] + while k > self.num_envs - self.n_tasks: + tasks += super().sample(k=self.num_envs - self.n_tasks) + k -= (self.num_envs - self.n_tasks) + if k > 0: + tasks += super().sample(k=k) + return tasks diff --git a/tests/sync_test_env.py b/tests/sync_test_env.py new file mode 100644 index 00000000..1c054228 --- /dev/null +++ b/tests/sync_test_env.py @@ -0,0 +1,33 @@ +import warnings +import gymnasium as gym +from syllabus.core import TaskEnv +from syllabus.task_space import TaskSpace + + +class SyncTestEnv(TaskEnv): + def __init__(self, num_episodes, num_steps=100): + super().__init__() + self.num_steps = num_steps + self.action_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Tuple((gym.spaces.Discrete(self.num_steps), gym.spaces.Discrete(2))) + self.task_space = TaskSpace(gym.spaces.Discrete(num_episodes + 1), ["error task"] + [f"task {i+1}" for i in range(num_episodes)]) + self.task = "error_task" + + def reset(self, new_task=None): + if new_task == "error task": + warnings.warn("Received error task. This likely means that too many tasks are being requested.") + if new_task is None: + warnings.warn("No task provided. Resetting to error task.") + self.task = new_task + self._turn = 0 + return (self._turn, None), {"content": "reset", "task": self.task} + + def step(self, action): + self._turn += 1 + + obs = self.observation((self._turn, action)) + rew = 1 + term = self._turn >= self.num_steps + trunc = False + info = {"content": "step", "task_completion": self._task_completion(obs, rew, term, trunc, {})} + return obs, rew, term, trunc, info diff --git a/tests/test_stat_recorder.py b/tests/test_stat_recorder.py new file mode 100644 index 00000000..89144e46 --- /dev/null +++ b/tests/test_stat_recorder.py @@ -0,0 +1,89 @@ +import gymnasium as gym +import numpy as np +from collections import deque +from syllabus.task_space import TaskSpace +from syllabus.core import StatRecorder + +def test(stat_recorder, simulated_eps_info, expected_episode_reward_mean_by_task, expected_episode_reward_std_by_task, expected_episode_length_mean_by_task, expected_episode_length_std_by_task, expected_normalized_reward): + num_passed = 0 + num_failed = 0 + for i in range(len(simulated_eps_info)): + task_id, eps_reward, eps_length = simulated_eps_info[i] + stat_recorder.record(eps_reward, eps_length, task_id) + rewards_for_task = deque([x[1] for x in simulated_eps_info[:i+1] if x[0]==task_id], maxlen=10) + normalized_reward_for_task = stat_recorder.normalize(rewards_for_task, task_id) + try: + assert abs(stat_recorder.stats[task_id]['mean_r'] - expected_episode_reward_mean_by_task[i]) < 1e-7 + assert abs(np.sqrt(stat_recorder.stats[task_id]['var_r']) - expected_episode_reward_std_by_task[i]) < 1e-7 + assert abs(stat_recorder.stats[task_id]['mean_l'] - expected_episode_length_mean_by_task[i]) < 1e-7 + assert abs(np.sqrt(stat_recorder.stats[task_id]['var_l']) - expected_episode_length_std_by_task[i]) < 1e-7 + for j in range(len(normalized_reward_for_task)): + assert abs(normalized_reward_for_task[j] - expected_normalized_reward[i][j]) < 1e-7 + print(f"Test case {i} PASSED.") + num_passed += 1 + except AssertionError: + print(f"Test case {i} FAILED.") + num_failed += 1 + print(f"{len(simulated_eps_info)} tests total, {num_passed} tests passed, {num_failed} test failed. Pass rate: {num_passed / len(simulated_eps_info) * 100}%.\n") + +def main(): + """ + simulated_eps_info: A list of tuples simulateing episodic infomation. + Each tuple satisfies the format (task_id, episode_return, episode_length). + """ + simulated_eps_info = [(0, 5.0, 48), (1, 1.0, 75), (2, 2.0, 36), (2, 4.0, 65), (0, 1.0, 54), (1, 3.0, 82), + (0, 2.0, 39), (2, 3.0, 80), (0, 4.0, 57), (1, 0.0, 94), (1, 2.0, 64), (0, 0.0, 45), + (2, 1.0, 86), (0, 2.0, 68), (1, 2.0, 92), (2, 1.0, 71), (0, 3.0, 32), (2, 1.0, 47)] + task_space = TaskSpace(gym.spaces.Discrete(3), list(np.arange(0, 3))) + + """ + Testing StatRecorder by calculating running average. + """ + print("Testing StatRecorder by calculating running average: ") + expected_episode_reward_mean_by_task = [5.0, 1.0, 2.0, 3.0, 3.0, 2.0, + 2.6666667, 3.0, 3.0, 1.3333333, 1.5, 2.4, + 2.5, 2.3333333, 1.6, 2.2, 2.4285714, 2.0 + ] + expected_episode_reward_std_by_task = [0.0, 0.0, 0.0, 1.0, 2.0, 1.0, + 1.6996732, 0.8164966, 1.5811388, 1.2472191, 1.118034, 1.8547237, + 1.118034, 1.6996732, 1.0198039, 1.1661904, 1.5907898, 1.1547005] + expected_episode_length_mean_by_task = [48, 75, 36, 50.5, 51, 78.5, + 47, 60.3333333, 49.5, 83.6666667, 78.75, 48.6, + 66.75, 51.8333333, 81.4, 67.6, 49, 64.1666667] + expected_episode_length_std_by_task = [0.0, 0.0, 0.0, 14.5, 3.0, 3.5, + 6.164414, 18.2635034, 6.8738635, 7.8457349, 10.8943793, 6.406247, 19.3309984, 9.2990442, 11.0923397, 17.3735431, 11.0582871, 17.620222] + expected_normalized_reward = [deque([0.0]), deque([0.0]), deque([0.0]), deque([-1.0, 1.0]), deque([1.0, -1.0]), deque([-1.0, 1.0]), + deque([1.3728129, -0.9805807, -0.3922323]), deque([-1.2247449, 1.2247449, 0.0]), deque([1.2649111, -1.2649111, -0.6324555, 0.6324555]), deque([-0.2672612, 1.3363062, -1.069045]), deque([-0.4472136, 1.3416408, -1.3416408, 0.4472136]), deque([1.4018261, -0.7548294, -0.2156655, 0.8626622, -1.2939933]), + deque([-0.4472136, 1.3416408, 0.4472136, -1.3416408]), deque([1.5689291, -0.7844645, -0.1961161, 0.9805807, -1.3728129, -0.1961161]), deque([-0.5883484, 1.3728129, -1.5689291, 0.3922323, 0.3922323]), deque([-0.1714986, 1.5434873, 0.6859943, -1.0289915, -1.0289915]), deque([1.6164477, -0.8980265, -0.269408, 0.9878292, -1.5266451, -0.269408, 0.3592106]), deque([0.0, 1.7320508, 0.8660254, -0.8660254, -0.8660254, -0.8660254]) + ] + + stat_recorder = StatRecorder(task_space) + test(stat_recorder, simulated_eps_info, expected_episode_reward_mean_by_task, expected_episode_reward_std_by_task, expected_episode_length_mean_by_task, expected_episode_length_std_by_task, expected_normalized_reward) + + """ + Testing StatRecorder by calculating based on last n episodes. + """ + calc_past_n = 3 + print(f"Testing StatRecorder by calculating based on last {calc_past_n} episodes: ") + expected_episode_reward_mean_by_task = [5.0, 1.0, 2.0, 3.0, 3.0, 2.0, + 2.6666667, 3.0, 2.3333333, 1.3333333, 1.6666667, 2.0, + 2.6666667, 2.0, 1.3333333, 1.6666667, 1.6666667, 1.0] + expected_episode_reward_std_by_task = [0.0, 0.0, 0.0, 1.0, 2.0, 1.0, + 1.6996732, 0.8164966, 1.2472191, 1.2472191, 1.2472191, 1.6329932, + 1.2472191, 1.6329932, 0.942809, 0.942809, 1.2472191, 0.0] + expected_episode_length_mean_by_task = [48, 75, 36, 50.5, 51, 78.5, + 47, 60.3333333, 50, 83.6666667, 80, 47, + 77, 56.6666667, 83.3333333, 79, 48.3333333, 68] + expected_episode_length_std_by_task = [0.0, 0.0, 0.0, 14.5, 3.0, 3.5, + 6.164414, 18.2635034, 7.8740079, 7.8457349, 12.328828, 7.4833148, + 8.8317609, 9.3926685, 13.6950924, 6.164414, 14.8847424, 16.0623784] + expected_normalized_reward = [deque([0.0]), deque([0.0]), deque([0.0]), deque([-1.0, 1.0]), deque([1.0, -1.0]), deque([-1.0, 1.0]), + deque([1.3728129, -0.9805807, -0.3922323]), deque([-1.2247449, 1.2247449, 0.0]), deque([2.1380899, -1.069045, -0.2672612, 1.3363062]), deque([-0.2672612, 1.3363062, -1.069045]), deque([-0.5345225, 1.069045, -1.3363062, 0.2672612]), deque([1.8371173, -0.6123724, 0.0, 1.2247449, -1.2247449]), + deque([-0.5345225, 1.069045, 0.2672612, -1.3363062]), deque([1.8371173, -0.6123724, 0.0, 1.2247449, -1.2247449, 0.0]), deque([-0.3535534, 1.767767, -1.4142136, 0.7071068, 0.7071068]), deque([0.3535534, 2.4748737, 1.4142136, -0.7071068, -0.7071068]), deque([2.6726124, -0.5345225, 0.2672612, 1.8708287, -1.3363062, 0.2672612, 1.069045]), deque([100.0, 300.0, 200.0, 0.0, 0.0, 0.0]) + ] + + stat_recorder = StatRecorder(task_space, calc_past_n) + test(stat_recorder, simulated_eps_info, expected_episode_reward_mean_by_task, expected_episode_reward_std_by_task, expected_episode_length_mean_by_task, expected_episode_length_std_by_task, expected_normalized_reward) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..98bac823 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,278 @@ +import time +import warnings +from multiprocessing import Process + +import gym as openai_gym +import gymnasium as gym +import numpy as np +import ray +import torch +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 + +from syllabus.core import MultiProcessingSyncWrapper, RaySyncWrapper, ReinitTaskWrapper +from syllabus.examples.task_wrappers.cartpole_task_wrapper import CartPoleTaskWrapper +from syllabus.task_space import TaskSpace +from syllabus.tests import SyncTestEnv + + +def evaluate_random_policy(make_env, num_episodes=100, seeds=None): + env = make_env(seed=seeds[0] if seeds else None) + + # Seed environment + env.action_space.seed(0) + env.observation_space.seed(0) + + episode_returns = [] + + for i in range(num_episodes): + episode_return = 0 + if seeds: + _ = env.reset(new_task=seeds[i]) + env.action_space.seed(0) + env.observation_space.seed(0) + else: + _ = env.reset() + term = trunc = False + while not (term or trunc): + action = env.action_space.sample() + _, rew, term, trunc, _ = env.step(action) + episode_return += rew + episode_returns.append(episode_return) + + avg_return = sum(episode_returns) / len(episode_returns) + # print(f"Average Episodic Return: {avg_return}") + return avg_return, episode_returns + + +def run_episode(env, new_task=None, curriculum=None, env_id=0): + """Run a single episode of the environment.""" + if new_task is not None: + obs = env.reset(new_task=new_task) + else: + obs = env.reset() + term = trunc = False + ep_rew = 0 + ep_len = 0 + while not (term or trunc): + action = env.action_space.sample() + obs, rew, term, trunc, info = env.step(action) + if curriculum and curriculum.requires_step_updates: + curriculum.update_on_step(env.task_space.encode(env.task), obs, rew, term, trunc, info, env_id=env_id) + curriculum.update_task_progress(env.task_space.encode(env.task), info["task_completion"], env_id=env_id) + ep_rew += rew + ep_len += 1 + if curriculum and curriculum.requires_episode_updates: + curriculum.update_on_episode(ep_rew, ep_len, env.task_space.encode(env.task), env_id=env_id) + return ep_rew + + +def run_set_length(env, curriculum=None, episodes=None, steps=None, env_id=0, env_outputs=None): + """Run environment for a set number of episodes or steps.""" + assert episodes is not None or steps is not None, "Must specify either episodes or steps." + assert episodes is None or steps is None, "Cannot specify both episodes and steps." + total_episodes = episodes if episodes is not None else 2**16 - 1 + total_steps = steps if steps is not None else 2**16 - 1 + n_steps = 0 + n_episodes = 0 + + # Resume stepping from the last observation. + if env_outputs is None: + obs = env.reset(new_task=curriculum.sample()[0] if curriculum else None) + + while n_episodes < total_episodes and n_steps < total_steps: + term = trunc = False + ep_rew = 0 + ep_len = 0 + while not (term or trunc) and n_steps < total_steps: + action = env.action_space.sample() + obs, rew, term, trunc, info = env.step(action) + if curriculum and curriculum.requires_step_updates: + curriculum.update_on_step(env.task_space.encode(env.task), obs, rew, term, trunc, info, env_id=env_id) + curriculum.update_task_progress(env.task_space.encode(env.task), info["task_completion"], env_id=env_id) + ep_rew += rew + ep_len += 1 + n_steps += 1 + if (term or trunc) and curriculum and curriculum.requires_episode_updates: + curriculum.update_on_episode(ep_rew, ep_len, env.task_space.encode(env.task), env_id=env_id) + n_episodes += 1 + obs = env.reset(new_task=curriculum.sample()[0] if curriculum else None) + + return (obs, rew, term, trunc, info) + + +def run_episodes(env_fn, env_args, env_kwargs, curriculum=None, num_episodes=10, env_id=0): + """Run multiple episodes of the environment.""" + env = env_fn(env_args=env_args, env_kwargs=env_kwargs) + ep_rews = [] + for _ in range(num_episodes): + if curriculum: + task = env.task_space.decode(curriculum.sample()[0]) + ep_rews.append(run_episode(env, new_task=task, curriculum=curriculum, env_id=env_id)) + else: + ep_rews.append(run_episode(env)) + env.close() + + +def run_episodes_queue(env_fn, env_args, env_kwargs, curriculum_components, sync=True, num_episodes=10, update_on_step=True, buffer_size=2, env_id=0): + env = env_fn(curriculum_components, env_args=env_args, env_kwargs=env_kwargs, type="queue", update_on_step=update_on_step, buffer_size=buffer_size) if sync else env_fn(env_args=env_args, env_kwargs=env_kwargs) + ep_rews = [] + for _ in range(num_episodes): + ep_rews.append(run_episode(env, env_id=env_id)) + env.close() + + +@ray.remote +def run_episodes_ray(env_fn, env_args, env_kwargs, sync=True, num_episodes=10, update_on_step=True): + env = env_fn(env_args=env_args, env_kwargs=env_kwargs, type="ray", update_on_step=update_on_step) if sync else env_fn(env_args=env_args, env_kwargs=env_kwargs) + ep_rews = [] + for _ in range(num_episodes): + ep_rews.append(run_episode(env)) + env.close() + + +def run_single_process(env_fn, env_args=(), env_kwargs={}, curriculum=None, num_envs=2, num_episodes=10): + start = time.time() + for _ in range(num_envs): + run_episodes(env_fn, env_args, env_kwargs, curriculum=curriculum, num_episodes=num_episodes) + end = time.time() + native_speed = end - start + return native_speed + + +def run_native_multiprocess(env_fn, env_args=(), env_kwargs={}, curriculum=None, num_envs=2, num_episodes=10, update_on_step=True, buffer_size=2): + start = time.time() + # Choose multiprocessing and curriculum methods + if curriculum: + target = run_episodes_queue + args = (env_fn, env_args, env_kwargs, curriculum.get_components(), True, num_episodes, update_on_step and curriculum.curriculum.requires_step_updates, buffer_size) + else: + target = run_episodes + args = (env_fn, env_args, env_kwargs, (), num_episodes) + + # Run episodes + actors = [] + for i in range(num_envs): + nargs = args + (i,) + actors.append(Process(target=target, args=nargs)) + for actor in actors: + actor.start() + for actor in actors: + actor.join() + end = time.time() + native_speed = end - start + + # Stop curriculum to prevent it from slowing down the next test + if curriculum: + curriculum.stop() + return native_speed + + +def run_ray_multiprocess(env_fn, env_args=(), env_kwargs={}, curriculum=None, num_envs=2, num_episodes=10, update_on_step=True): + if curriculum: + target = run_episodes_ray + args = (env_fn, env_args, env_kwargs, True, num_episodes, update_on_step) + else: + target = run_episodes_ray + args = (env_fn, env_args, env_kwargs, False, num_episodes, update_on_step) + + start = time.time() + remotes = [] + for _ in range(num_envs): + remotes.append(target.remote(*args)) + ray.get(remotes) + end = time.time() + ray_speed = end - start + if curriculum: + ray.kill(curriculum.curriculum) + return ray_speed + +def get_test_values(x): + return torch.unsqueeze(torch.Tensor(np.array([0] * len(x))), -1) + + +# Sync Test Environment +def create_synctest_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + env = SyncTestEnv(*env_args, **env_kwargs) + if type == "queue": + env = MultiProcessingSyncWrapper(env, *args, task_space=env.task_space, **kwargs) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Cartpole Tests +def create_cartpole_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + env = gym.make("CartPole-v1", **env_kwargs) + env = CartPoleTaskWrapper(env) + + if type == "queue": + env = MultiProcessingSyncWrapper(env, *args, task_space=env.task_space, **kwargs) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Nethack Tests +def create_nethack_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + try: + from nle.env.tasks import NetHackScore + + from syllabus.examples.task_wrappers.nethack_wrappers import \ + NethackTaskWrapper + except ImportError: + warnings.warn("Unable to import nle.") + + env = NetHackScore(*env_args, **env_kwargs) + env = NethackTaskWrapper(env) + + if type == "queue": + env = MultiProcessingSyncWrapper( + env, *args, task_space=env.task_space, **kwargs + ) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Procgen Tests +def create_procgen_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + try: + import procgen + + from syllabus.examples.task_wrappers.procgen_task_wrapper import \ + ProcgenTaskWrapper + except ImportError: + warnings.warn("Unable to import procgen.") + + env = openai_gym.make("procgen-bigfish-v0", *env_args, **env_kwargs) + env = GymV21CompatibilityV0(env=env) + env = ProcgenTaskWrapper(env, "bigfish") + + if type == "queue": + env = MultiProcessingSyncWrapper( + env, *args, task_space=env.task_space, **kwargs + ) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Minigrid Tests +def create_minigrid_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + try: + from gym_minigrid.envs import DoorKeyEnv # noqa: F401 + from gym_minigrid.register import env_list + except ImportError: + warnings.warn("Unable to import gym_minigrid.") + env = gym.make("MiniGrid-DoorKey-5x5-v0", **env_kwargs) + + def create_env(task): + return gym.make(task) + + task_space = TaskSpace(gym.spaces.Discrete(len(env_list)), env_list) + env = ReinitTaskWrapper(env, create_env, task_space=task_space) + if type == "queue": + env = MultiProcessingSyncWrapper(env, *args, task_space=env.task_space, **kwargs) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env