From 78f450ab3fe0e4d421b3cdcc2d996c5731c1803c Mon Sep 17 00:00:00 2001 From: Ryan Sullivan Date: Tue, 23 Apr 2024 23:19:45 -0400 Subject: [PATCH 01/14] Remove unused eval --- .../cleanrl_procgen_centralplr.py | 52 ---------------- .../training_scripts/cleanrl_procgen_plr.py | 62 ++----------------- 2 files changed, 5 insertions(+), 109 deletions(-) diff --git a/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py b/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py index b848d693..704dfcae 100644 --- a/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py +++ b/syllabus/examples/training_scripts/cleanrl_procgen_centralplr.py @@ -150,46 +150,6 @@ def wrap_vecenv(vecenv): return vecenv -def full_level_replay_evaluate( - env_name, - policy, - num_episodes, - device, - num_levels=1 # Not used -): - policy.eval() - - eval_envs = ProcgenEnv( - num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False - ) - eval_envs = VecExtractDictObs(eval_envs, "rgb") - eval_envs = wrap_vecenv(eval_envs) - - # Seed environments - seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] - for i, seed in enumerate(seeds): - eval_envs.seed(seed, i) - - eval_obs, _ = eval_envs.reset() - eval_episode_rewards = [-1] * num_episodes - - while -1 in eval_episode_rewards: - with torch.no_grad(): - eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) - - eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) - for i, info in enumerate(infos): - if 'episode' in info.keys() and eval_episode_rewards[i] == -1: - eval_episode_rewards[i] = info['episode']['r'] - - mean_returns = np.mean(eval_episode_rewards) - stddev_returns = np.std(eval_episode_rewards) - env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] - normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) - 
policy.train() - return mean_returns, stddev_returns, normalized_mean_returns - - def level_replay_evaluate( env_name, policy, @@ -493,15 +453,9 @@ def level_replay_evaluate( mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=0 ) - full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=0 - ) mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=200 ) - full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=200 - ) # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) @@ -519,16 +473,10 @@ def level_replay_evaluate( writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step) writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step) writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step) - writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step) writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step) writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step) writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step) - writer.add_scalar("train_eval/full_mean_episode_return", 
full_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step) writer.add_scalar("curriculum/completed_episodes", completed_episodes, step) diff --git a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py index dabcd500..8536cfed 100644 --- a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py +++ b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py @@ -150,46 +150,6 @@ def wrap_vecenv(vecenv): return vecenv -def full_level_replay_evaluate( - env_name, - policy, - num_episodes, - device, - num_levels=1 # Not used -): - policy.eval() - - eval_envs = ProcgenEnv( - num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False - ) - eval_envs = VecExtractDictObs(eval_envs, "rgb") - eval_envs = wrap_vecenv(eval_envs) - - # Seed environments - seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] - for i, seed in enumerate(seeds): - eval_envs.seed(seed, i) - - eval_obs, _ = eval_envs.reset() - eval_episode_rewards = [-1] * num_episodes - - while -1 in eval_episode_rewards: - with torch.no_grad(): - eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) - - eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) - for i, info in enumerate(infos): - if 'episode' in info.keys() and eval_episode_rewards[i] == -1: - eval_episode_rewards[i] = info['episode']['r'] - - mean_returns = np.mean(eval_episode_rewards) - stddev_returns = np.std(eval_episode_rewards) - env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] - normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) - policy.train() - return mean_returns, 
stddev_returns, normalized_mean_returns - - def level_replay_evaluate( env_name, policy, @@ -296,11 +256,11 @@ def get_value(obs): print("Using sequential curriculum.") curricula = [] stopping = [] - for i in range(199): - curricula.append(i + 1) - stopping.append("steps>=50000") - curricula.append(list(range(i + 1))) - stopping.append("steps>=50000") + for i in range(0, 199, 10): + curricula.append(list(range(i, i+10))) + stopping.append("steps>=500000") + curricula.append(list(range(i + 10))) + stopping.append("steps>=500000") curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space) else: raise ValueError(f"Unknown curriculum method {args.curriculum_method}") @@ -485,15 +445,9 @@ def get_value(obs): mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=0 ) - full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=0 - ) mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate( args.env_id, agent, args.num_eval_episodes, device, num_levels=200 ) - full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate( - args.env_id, agent, args.num_eval_episodes, device, num_levels=200 - ) # TRY NOT TO MODIFY: record rewards for plotting purposes writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) @@ -511,16 +465,10 @@ def get_value(obs): writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step) writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step) writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step) - writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step) - 
writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step) - writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step) writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step) writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step) writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step) - writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step) - writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step) writer.add_scalar("curriculum/completed_episodes", completed_episodes, step) From 9370a1c28741bd0985eb1791ad7707d7fa12a48a Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 23 Apr 2024 23:06:58 -0400 Subject: [PATCH 02/14] merge in upstram --- syllabus/core/curriculum_base.py | 18 ++++- syllabus/core/curriculum_sync_wrapper.py | 6 ++ syllabus/core/stat_recorder.py | 78 +++++++++++++++++++ syllabus/curricula/plr/plr_wrapper.py | 19 +++++ .../training_scripts/cleanrl_procgen_plr.py | 29 ++++++- 5 files changed, 144 insertions(+), 6 deletions(-) create mode 100644 syllabus/core/stat_recorder.py diff --git a/syllabus/core/curriculum_base.py b/syllabus/core/curriculum_base.py index 4ca9aeb0..5e75faff 100644 --- a/syllabus/core/curriculum_base.py +++ b/syllabus/core/curriculum_base.py @@ -6,6 +6,7 @@ from gymnasium.spaces import Dict from syllabus.task_space import TaskSpace +from .stat_recorder import StatRecorder # TODO: Move non-generic logic to Uniform class. 
Allow subclasses to call super for generic error handling @@ -28,6 +29,7 @@ def __init__(self, task_space: TaskSpace, random_start_tasks: int = 0, task_name self.completed_tasks = 0 self.task_names = task_names self.n_updates = 0 + self.stat_recorder = StatRecorder(self.task_space) if self.num_tasks == 0: warnings.warn("Task space is empty. This will cause errors during sampling if no tasks are added.") @@ -74,6 +76,7 @@ def update_task_progress(self, task: typing.Any, progress: Tuple[float, bool], e :param task: Task for which progress is being updated. :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task. """ + self.completed_tasks += 1 def update_on_step(self, task: typing.Any, obs: typing.Any, rew: float, term: bool, trunc: bool, info: dict, env_id: int = None) -> None: @@ -96,19 +99,26 @@ def update_on_step_batch(self, step_results: List[typing.Tuple[Any, Any, int, in :param step_results: List of step results """ - tasks, obs, rews, terms, truncs, infos = tuple(step_results) + + #obs, rews, terms, truncs, infos = tuple(step_results) + obs = [r[0] for r in step_results] + rews = [r[1] for r in step_results] + terms = [r[2] for r in step_results] + truncs = [r[3] for r in step_results] + infos = [r[4] for r in step_results] + for i in range(len(obs)): self.update_on_step(tasks[i], obs[i], rews[i], terms[i], truncs[i], infos[i], env_id=env_id) - def update_on_episode(self, episode_return: float, episode_length: int, episode_task: Any, env_id: int = None) -> None: + def update_on_episode(self, episode_return: float, episode_length: int, episode_task, env_id=None) -> None: """Update the curriculum with episode results from the environment. 
:param episode_return: Episodic return :param trajectory: trajectory of (s, a, r, s, ...), defaults to None :raises NotImplementedError: """ - # TODO: Add update_on_episode option similar to update-on_step - pass + self.stat_recorder.record(episode_return, episode_length, episode_task, env_id) + #raise NotImplementedError("Not yet implemented.") def update_on_demand(self, metrics: Dict): """Update the curriculum with arbitrary inputs. diff --git a/syllabus/core/curriculum_sync_wrapper.py b/syllabus/core/curriculum_sync_wrapper.py index f9866438..7b7a75f7 100644 --- a/syllabus/core/curriculum_sync_wrapper.py +++ b/syllabus/core/curriculum_sync_wrapper.py @@ -55,6 +55,9 @@ def log_metrics(self, writer, step=None): def update_on_step_batch(self, step_results): self.curriculum.update_on_step_batch(step_results) + def update_on_episode(self, episode_return, episode_length, episode_task, env_id=None): + self.curriculum.update_on_episode(episode_return, episode_length, episode_task, env_id) + def update(self, metrics): self.curriculum.update(metrics) @@ -234,6 +237,9 @@ def _update_queues(self): else: time.sleep(0.01) + def update_on_episode(self, episode_return, episode_length, episode_task, env_id=None): + super().update_on_episode(episode_return, episode_length, episode_task, env_id) + def log_metrics(self, writer, step=None): super().log_metrics(writer, step=step) if self.get_components()._debug: diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py new file mode 100644 index 00000000..ea5f41a6 --- /dev/null +++ b/syllabus/core/stat_recorder.py @@ -0,0 +1,78 @@ +import numpy as np +from syllabus.task_space import TaskSpace +from gymnasium.spaces import Discrete #, MultiDiscrete? +import json +import os + +def convert_numpy(obj): + if isinstance(obj, np.generic): + return obj.item() # Use .item() to convert numpy types to native Python types + raise TypeError + +class StatRecorder: + """ + Individual stat tracking for each task. 
+ """ + + def __init__(self, task_space: TaskSpace): + """Initialize the StatRecorder""" + + self.write_path = '/Users/allisonyang/Downloads' + + self.task_space = task_space + + assert isinstance(self.task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." + assert isinstance(self.task_space.gym_space, Discrete), f"assume that task_space should be of type Discrete" + + self.tasks = self.task_space.get_tasks() + self.num_tasks = self.task_space.num_tasks + + self.records = {task: [] for task in self.tasks} + self.stats = {task: {} for task in self.tasks} + + def record(self, episode_return: float, episode_length: int, episode_task, env_id=None): + """ + Records the length and return of an episode for a given task. + + :param task: Identifier for the task + :param episode_length: Length of the episode, i.e. the total number of steps taken during the episode + :param episodic_return: Total return for the episode + """ + + if episode_task in self.tasks: + self.records[episode_task].append({ + "r": episode_return, + "l": episode_length, + "env_id": env_id + }) + self.stats[episode_task]['mean_r'] = np.mean([record["r"] for record in self.records[episode_task]]) + self.stats[episode_task]['var_r'] = np.var([record["r"] for record in self.records[episode_task]]) + self.stats[episode_task]['mean_l'] = np.mean([record["l"] for record in self.records[episode_task]]) + self.stats[episode_task]['var_l'] = np.var([record["l"] for record in self.records[episode_task]]) + else: + raise ValueError("Unknown task") + + """ + records = json.dumps(self.records, default=convert_numpy) + with open(os.path.join(self.write_path, 'records.json'), "w") as file: + file.write(records) + stats = json.dumps(self.stats, default=convert_numpy) + with open(os.path.join(self.write_path, 'stats.json'), "w") as file: + file.write(stats) + """ + + def get_task_return_avg(self, task): + """Returns the average episode length for a given task.""" + return 
np.mean([record["r"] for record in self.records[task]]) + + def get_task_return_sum(self, task): + """Returns the total return for a given task.""" + return sum([record["r"] for record in self.records[task]]) + + def get_task_return_variance(self, task): + """Returns the variance of returns for a given task.""" + return np.var([record["r"] for record in self.records[task]]) + + def get_task_return_std(self, task): + """Returns the standard deviation of returns for a given task.""" + return np.std([record["r"] for record in self.records[task]]) \ No newline at end of file diff --git a/syllabus/curricula/plr/plr_wrapper.py b/syllabus/curricula/plr/plr_wrapper.py index 9c808ddc..f4dae1f4 100644 --- a/syllabus/curricula/plr/plr_wrapper.py +++ b/syllabus/curricula/plr/plr_wrapper.py @@ -264,6 +264,25 @@ def update_on_step_batch( task=tasks, ) + #def update_on_episode(self, episode_return: float, episode_task, env_id: int = None) -> None: + #""" + ##Update the curriculum with episode results from the environment. + #"" + #raise NotImplementedError( + #"PrioritizedLevelReplay does not support the episode updates. Use on_demand from the learner process." + #) + + + def update_task_progress(self, task: Any, success_prob: float, env_id: int = None) -> None: + """ + Update the curriculum with a task and its success probability upon + success or failure. + """ + assert env_id is not None, "env_id must be provided for PLR updates." 
+ self._rollouts.insert_at_index( + env_id, + task=task, + ) # Update task sampler if env_id in self._rollouts.ready_buffers: self._update_sampler(env_id) diff --git a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py index 8536cfed..536ecd87 100644 --- a/syllabus/examples/training_scripts/cleanrl_procgen_plr.py +++ b/syllabus/examples/training_scripts/cleanrl_procgen_plr.py @@ -211,7 +211,7 @@ def get_value(obs): ) # wandb.run.log_code("./syllabus/examples") - writer = SummaryWriter(os.path.join(args.logging_dir, "./runs/{run_name}")) + writer = SummaryWriter(os.path.join(args.logging_dir, f"./runs/{run_name}")) writer.add_text( "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), @@ -269,7 +269,8 @@ def get_value(obs): # env setup print("Creating env") - envs = gym.vector.AsyncVectorEnv( + envs = gym.vector.SyncVectorEnv( + #envs = gym.vector.AsyncVectorEnv( [ make_env( args.env_id, @@ -282,6 +283,24 @@ def get_value(obs): ) envs = wrap_vecenv(envs) + test_eval_envs = gym.vector.SyncVectorEnv( + #test_eval_envs = gym.vector.AsyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, num_levels=0) + for i in range(args.num_eval_episodes) + ] + ) + test_eval_envs = wrap_vecenv(test_eval_envs) + + train_eval_envs = gym.vector.SyncVectorEnv( + #train_eval_envs = gym.vector.AsyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, num_levels=200) + for i in range(args.num_eval_episodes) + ] + ) + train_eval_envs = wrap_vecenv(train_eval_envs) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" print("Creating agent") agent = ProcgenAgent( @@ -344,6 +363,12 @@ def get_value(obs): writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) if curriculum is not None: curriculum.log_metrics(writer, global_step) + + # track return for individual tasks + idx = 
info.index(item) + multiprocessing_sync_wrapper_envs = envs.venv.venv.envs # extract envs of class MultiProcessingSyncWrapper from envs of class VecNormalize, which has access to task id + episode_task = multiprocessing_sync_wrapper_envs[idx]._latest_task + curriculum.update_on_episode(item["episode"]["r"], item["episode"]["l"], episode_task, args.env_id) break # bootstrap value if not done From c470517d66d42e17ba951ce52efb6df4919c8a2c Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 2 Apr 2024 13:58:15 -0400 Subject: [PATCH 03/14] minor changes --- syllabus/examples/task_wrappers/procgen_task_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/syllabus/examples/task_wrappers/procgen_task_wrapper.py b/syllabus/examples/task_wrappers/procgen_task_wrapper.py index 2296fd58..99be70a0 100644 --- a/syllabus/examples/task_wrappers/procgen_task_wrapper.py +++ b/syllabus/examples/task_wrappers/procgen_task_wrapper.py @@ -39,7 +39,8 @@ def __init__(self, env: gym.Env, env_id, seed=0): self.observation_space = self.env.observation_space def seed(self, seed): - self.env.gym_env.unwrapped._venv.seed(int(seed), 0) + #self.env.gym_env.unwrapped._venv.seed(int(seed), 0) + pass def reset(self, new_task=None, **kwargs): """ From 4158368a5d70d21b1d7380f13d2fe86883f6fdc7 Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 2 Apr 2024 15:09:19 -0400 Subject: [PATCH 04/14] merge in upstream --- syllabus/examples/task_wrappers/procgen_task_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syllabus/examples/task_wrappers/procgen_task_wrapper.py b/syllabus/examples/task_wrappers/procgen_task_wrapper.py index 99be70a0..497c9eca 100644 --- a/syllabus/examples/task_wrappers/procgen_task_wrapper.py +++ b/syllabus/examples/task_wrappers/procgen_task_wrapper.py @@ -39,8 +39,8 @@ def __init__(self, env: gym.Env, env_id, seed=0): self.observation_space = self.env.observation_space def seed(self, seed): - 
#self.env.gym_env.unwrapped._venv.seed(int(seed), 0) - pass + self.env.gym_env.unwrapped._venv.seed(int(seed), 0) + #pass def reset(self, new_task=None, **kwargs): """ From d04b6a7047567e5e18c8b840040647ce92b29e44 Mon Sep 17 00:00:00 2001 From: Xinchen Yang <91385265+xinchen-yang@users.noreply.github.com> Date: Thu, 4 Apr 2024 15:54:51 -0400 Subject: [PATCH 05/14] Update syllabus/core/stat_recorder.py Co-authored-by: Ryan Sullivan --- syllabus/core/stat_recorder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py index ea5f41a6..01af2379 100644 --- a/syllabus/core/stat_recorder.py +++ b/syllabus/core/stat_recorder.py @@ -22,7 +22,7 @@ def __init__(self, task_space: TaskSpace): self.task_space = task_space assert isinstance(self.task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." - assert isinstance(self.task_space.gym_space, Discrete), f"assume that task_space should be of type Discrete" + assert isinstance(self.task_space.gym_space, Discrete), f"Only Discrete task spaces are supported. 
Got {type(task_space.gym_space)}" self.tasks = self.task_space.get_tasks() self.num_tasks = self.task_space.num_tasks From d43decafa8a132146bb0c934fca6c5accbf6f1d9 Mon Sep 17 00:00:00 2001 From: allisony9 Date: Wed, 17 Apr 2024 12:51:08 -0400 Subject: [PATCH 06/14] implemented efficient reward normalization (running average & last N episodes); implemented the log_metrics function of the StatRecorder class for visualization on weights & biases --- syllabus/core/curriculum_base.py | 2 +- syllabus/core/curriculum_sync_wrapper.py | 1 + syllabus/core/stat_recorder.py | 130 ++++++++++++++++------- 3 files changed, 91 insertions(+), 42 deletions(-) diff --git a/syllabus/core/curriculum_base.py b/syllabus/core/curriculum_base.py index 5e75faff..4e7e4dc4 100644 --- a/syllabus/core/curriculum_base.py +++ b/syllabus/core/curriculum_base.py @@ -29,7 +29,7 @@ def __init__(self, task_space: TaskSpace, random_start_tasks: int = 0, task_name self.completed_tasks = 0 self.task_names = task_names self.n_updates = 0 - self.stat_recorder = StatRecorder(self.task_space) + self.stat_recorder = StatRecorder(self.task_space, self.task_names) if self.num_tasks == 0: warnings.warn("Task space is empty. 
This will cause errors during sampling if no tasks are added.") diff --git a/syllabus/core/curriculum_sync_wrapper.py b/syllabus/core/curriculum_sync_wrapper.py index 7b7a75f7..c1c2fb27 100644 --- a/syllabus/core/curriculum_sync_wrapper.py +++ b/syllabus/core/curriculum_sync_wrapper.py @@ -51,6 +51,7 @@ def update_on_step(self, task, step, reward, term, trunc): def log_metrics(self, writer, step=None): self.curriculum.log_metrics(writer, step=step) + self.curriculum.stat_recorder.log_metrics(writer, step=step) def update_on_step_batch(self, step_results): self.curriculum.update_on_step_batch(step_results) diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py index 01af2379..3147a11c 100644 --- a/syllabus/core/stat_recorder.py +++ b/syllabus/core/stat_recorder.py @@ -1,25 +1,22 @@ +import warnings import numpy as np from syllabus.task_space import TaskSpace +from typing import Callable from gymnasium.spaces import Discrete #, MultiDiscrete? import json import os -def convert_numpy(obj): - if isinstance(obj, np.generic): - return obj.item() # Use .item() to convert numpy types to native Python types - raise TypeError - class StatRecorder: """ Individual stat tracking for each task. """ - def __init__(self, task_space: TaskSpace): + def __init__(self, task_space: TaskSpace, task_names: Callable = None, calc_past_N=None): """Initialize the StatRecorder""" - self.write_path = '/Users/allisonyang/Downloads' - self.task_space = task_space + self.task_names = task_names + self.calc_past_N = calc_past_N assert isinstance(self.task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." assert isinstance(self.task_space.gym_space, Discrete), f"Only Discrete task spaces are supported. Got {type(task_space.gym_space)}" @@ -34,45 +31,96 @@ def record(self, episode_return: float, episode_length: int, episode_task, env_i """ Records the length and return of an episode for a given task. 
- :param task: Identifier for the task + :param episode_task: Identifier for the task :param episode_length: Length of the episode, i.e. the total number of steps taken during the episode :param episodic_return: Total return for the episode """ if episode_task in self.tasks: - self.records[episode_task].append({ - "r": episode_return, - "l": episode_length, - "env_id": env_id - }) - self.stats[episode_task]['mean_r'] = np.mean([record["r"] for record in self.records[episode_task]]) - self.stats[episode_task]['var_r'] = np.var([record["r"] for record in self.records[episode_task]]) - self.stats[episode_task]['mean_l'] = np.mean([record["l"] for record in self.records[episode_task]]) - self.stats[episode_task]['var_l'] = np.var([record["l"] for record in self.records[episode_task]]) + if self.calc_past_N: + self.records[episode_task].append({ + "r": episode_return, + "l": episode_length, + "env_id": env_id + }) + self.records[episode_task] = self.records[episode_task][-self.calc_past_N:] + self.stats[episode_task]['mean_r'] = np.mean([record["r"] for record in self.records[episode_task]]) + self.stats[episode_task]['var_r'] = np.var([record["r"] for record in self.records[episode_task]]) + self.stats[episode_task]['mean_l'] = np.mean([record["l"] for record in self.records[episode_task]]) + self.stats[episode_task]['var_l'] = np.var([record["l"] for record in self.records[episode_task]]) + # only save mean/variance the past N episodes + else: + # save the mean/variance of all the episodes + if 'mean_r' not in self.stats[episode_task].keys(): + # the first episode for a task + self.stats[episode_task]['mean_r'] = episode_return + self.stats[episode_task]['mean_r_squared'] = episode_return ** 2 + self.stats[episode_task]['var_r'] = 0 + self.stats[episode_task]['mean_l'] = episode_length + self.stats[episode_task]['mean_l_squared'] = episode_length ** 2 + self.stats[episode_task]['var_l'] = 0 + else: + N_past = len(self.records[episode_task]) + + 
self.stats[episode_task]['mean_r'] =round((self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1), 4) + self.stats[episode_task]['mean_r_squared'] = round((self.stats[episode_task]['mean_r_squared'] * N_past + episode_return ** 2) / (N_past + 1), 4) + self.stats[episode_task]['var_r'] = round(self.stats[episode_task]['mean_r_squared'] - self.stats[episode_task]['mean_r'] ** 2, 4) + + self.stats[episode_task]['mean_l'] = round((self.stats[episode_task]['mean_l'] * N_past + episode_length) / (N_past + 1), 4) + self.stats[episode_task]['mean_l_squared'] = round((self.stats[episode_task]['mean_l_squared'] * N_past + episode_length ** 2) / (N_past + 1), 4) + self.stats[episode_task]['var_l'] = round(self.stats[episode_task]['mean_l_squared'] - self.stats[episode_task]['mean_l'] ** 2, 4) else: raise ValueError("Unknown task") - + + def log_metrics(self, writer, step=None, log_full_dist=False): + """Log the statistics of the first 5 tasks to the provided tensorboard writer. + + :param writer: Tensorboard summary writer. 
""" - records = json.dumps(self.records, default=convert_numpy) - with open(os.path.join(self.write_path, 'records.json'), "w") as file: - file.write(records) - stats = json.dumps(self.stats, default=convert_numpy) - with open(os.path.join(self.write_path, 'stats.json'), "w") as file: - file.write(stats) + try: + import wandb + tasks_to_log = self.tasks + if len(self.tasks) > 5 and not log_full_dist: + warnings.warn("Only logging stats for 5 tasks.") + tasks_to_log = self.tasks[:5] + if self.task_names: + for idx in tasks_to_log: + if self.stats[idx]: + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_mean", self.stats[idx]['mean_r'], step) + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_var", self.stats[idx]['var_r'], step) + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_mean", self.stats[idx]['mean_l'], step) + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_var", self.stats[idx]['var_l'], step) + else: + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_mean", 0, step) + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_var", 0, step) + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_mean", 0, step) + writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_var", 0, step) + else: + for idx in tasks_to_log: + if self.stats[idx]: + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", self.stats[idx]['mean_r'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", self.stats[idx]['var_r'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", self.stats[idx]['mean_l'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", self.stats[idx]['var_l'], step) + else: + 
writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", 0, step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", 0, step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", 0, step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", 0, step) + except ImportError: + warnings.warn("Wandb is not installed. Skipping logging.") + except wandb.errors.Error: + # No need to crash over logging :) + warnings.warn("Failed to log curriculum stats to wandb.") + + def output_results(self, output_path): """ - - def get_task_return_avg(self, task): - """Returns the average episode length for a given task.""" - return np.mean([record["r"] for record in self.records[task]]) - - def get_task_return_sum(self, task): - """Returns the total return for a given task.""" - return sum([record["r"] for record in self.records[task]]) - - def get_task_return_variance(self, task): - """Returns the variance of returns for a given task.""" - return np.var([record["r"] for record in self.records[task]]) - - def get_task_return_std(self, task): - """Returns the standard deviation of returns for a given task.""" - return np.std([record["r"] for record in self.records[task]]) \ No newline at end of file + Write task-specific statistics to file. 
+ """ + def convert_numpy(obj): + if isinstance(obj, np.generic): + return obj.item() # Use .item() to convert numpy types to native Python types + raise TypeError + stats = json.dumps(self.stats, default=convert_numpy) + with open(os.path.join(output_path, 'task_specific_stats.json'), "w") as file: + file.write(stats) \ No newline at end of file From cd9494bf26d6cca3372f8e0a1badd946dd144655 Mon Sep 17 00:00:00 2001 From: allisony9 Date: Fri, 19 Apr 2024 16:13:47 -0400 Subject: [PATCH 07/14] Simply StatRecorder code, use deque and defaultdict structure --- syllabus/core/curriculum_base.py | 2 +- syllabus/core/stat_recorder.py | 96 ++++++++++++-------------------- 2 files changed, 36 insertions(+), 62 deletions(-) diff --git a/syllabus/core/curriculum_base.py b/syllabus/core/curriculum_base.py index 4e7e4dc4..5e75faff 100644 --- a/syllabus/core/curriculum_base.py +++ b/syllabus/core/curriculum_base.py @@ -29,7 +29,7 @@ def __init__(self, task_space: TaskSpace, random_start_tasks: int = 0, task_name self.completed_tasks = 0 self.task_names = task_names self.n_updates = 0 - self.stat_recorder = StatRecorder(self.task_space, self.task_names) + self.stat_recorder = StatRecorder(self.task_space) if self.num_tasks == 0: warnings.warn("Task space is empty. This will cause errors during sampling if no tasks are added.") diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py index 3147a11c..26b7a86c 100644 --- a/syllabus/core/stat_recorder.py +++ b/syllabus/core/stat_recorder.py @@ -1,35 +1,39 @@ +import os +import json import warnings import numpy as np from syllabus.task_space import TaskSpace from typing import Callable from gymnasium.spaces import Discrete #, MultiDiscrete? -import json -import os +from collections import deque, defaultdict class StatRecorder: """ Individual stat tracking for each task. 
""" - def __init__(self, task_space: TaskSpace, task_names: Callable = None, calc_past_N=None): + def __init__(self, task_space: TaskSpace, keep_last_N=10, calc_past_N=None): """Initialize the StatRecorder""" self.task_space = task_space - self.task_names = task_names self.calc_past_N = calc_past_N assert isinstance(self.task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." assert isinstance(self.task_space.gym_space, Discrete), f"Only Discrete task spaces are supported. Got {type(task_space.gym_space)}" + if calc_past_N and calc_past_N > keep_last_N: + warnings.warn("The number of data points requested to calculate statistics exceeds the number of data points kept in memory. Will only use data points available in memory to calculate statistics instead.") self.tasks = self.task_space.get_tasks() self.num_tasks = self.task_space.num_tasks - self.records = {task: [] for task in self.tasks} - self.stats = {task: {} for task in self.tasks} + self.episode_returns = {task: deque(maxlen=keep_last_N) for task in self.tasks} + self.episode_lengths = {task: deque(maxlen=keep_last_N) for task in self.tasks} + self.env_ids = {task: deque(maxlen=keep_last_N) for task in self.tasks} + self.stats = {task: defaultdict(float) for task in self.tasks} def record(self, episode_return: float, episode_length: int, episode_task, env_id=None): """ - Records the length and return of an episode for a given task. + Record the length and return of an episode for a given task. :param episode_task: Identifier for the task :param episode_length: Length of the episode, i.e. 
the total number of steps taken during the episode @@ -38,37 +42,25 @@ def record(self, episode_return: float, episode_length: int, episode_task, env_i if episode_task in self.tasks: if self.calc_past_N: - self.records[episode_task].append({ - "r": episode_return, - "l": episode_length, - "env_id": env_id - }) - self.records[episode_task] = self.records[episode_task][-self.calc_past_N:] - self.stats[episode_task]['mean_r'] = np.mean([record["r"] for record in self.records[episode_task]]) - self.stats[episode_task]['var_r'] = np.var([record["r"] for record in self.records[episode_task]]) - self.stats[episode_task]['mean_l'] = np.mean([record["l"] for record in self.records[episode_task]]) - self.stats[episode_task]['var_l'] = np.var([record["l"] for record in self.records[episode_task]]) - # only save mean/variance the past N episodes + self.episode_returns[episode_task].append(episode_return) + self.episode_lengths[episode_task].append(episode_length) + self.env_ids[episode_task].append(env_id) + + self.stats[episode_task]['mean_r'] = np.mean(list(self.episode_returns[episode_task])[-self.calc_past_N:]) # I am not sure whether there is a more efficient way to slice to deque. 
I temperorily convert it to a list then slice it, which should cost O(n) + self.stats[episode_task]['var_r'] = np.var(list(self.episode_returns[episode_task])[-self.calc_past_N:]) + self.stats[episode_task]['mean_l'] = np.mean(list(self.episode_lengths[episode_task])[-self.calc_past_N:]) + self.stats[episode_task]['var_l'] = np.var(list(self.episode_lengths[episode_task])[-self.calc_past_N:]) else: # save the mean/variance of all the episodes - if 'mean_r' not in self.stats[episode_task].keys(): - # the first episode for a task - self.stats[episode_task]['mean_r'] = episode_return - self.stats[episode_task]['mean_r_squared'] = episode_return ** 2 - self.stats[episode_task]['var_r'] = 0 - self.stats[episode_task]['mean_l'] = episode_length - self.stats[episode_task]['mean_l_squared'] = episode_length ** 2 - self.stats[episode_task]['var_l'] = 0 - else: - N_past = len(self.records[episode_task]) - - self.stats[episode_task]['mean_r'] =round((self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1), 4) - self.stats[episode_task]['mean_r_squared'] = round((self.stats[episode_task]['mean_r_squared'] * N_past + episode_return ** 2) / (N_past + 1), 4) - self.stats[episode_task]['var_r'] = round(self.stats[episode_task]['mean_r_squared'] - self.stats[episode_task]['mean_r'] ** 2, 4) - - self.stats[episode_task]['mean_l'] = round((self.stats[episode_task]['mean_l'] * N_past + episode_length) / (N_past + 1), 4) - self.stats[episode_task]['mean_l_squared'] = round((self.stats[episode_task]['mean_l_squared'] * N_past + episode_length ** 2) / (N_past + 1), 4) - self.stats[episode_task]['var_l'] = round(self.stats[episode_task]['mean_l_squared'] - self.stats[episode_task]['mean_l'] ** 2, 4) + N_past = len(self.episode_returns[episode_task]) + + self.stats[episode_task]['mean_r'] = (self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1) + self.stats[episode_task]['mean_r_squared'] = (self.stats[episode_task]['mean_r_squared'] * N_past + 
episode_return ** 2) / (N_past + 1) + self.stats[episode_task]['var_r'] = self.stats[episode_task]['mean_r_squared'] - self.stats[episode_task]['mean_r'] ** 2 + + self.stats[episode_task]['mean_l'] = (self.stats[episode_task]['mean_l'] * N_past + episode_length) / (N_past + 1) + self.stats[episode_task]['mean_l_squared'] = (self.stats[episode_task]['mean_l_squared'] * N_past + episode_length ** 2) / (N_past + 1) + self.stats[episode_task]['var_l'] = self.stats[episode_task]['mean_l_squared'] - self.stats[episode_task]['mean_l'] ** 2 else: raise ValueError("Unknown task") @@ -83,37 +75,19 @@ def log_metrics(self, writer, step=None, log_full_dist=False): if len(self.tasks) > 5 and not log_full_dist: warnings.warn("Only logging stats for 5 tasks.") tasks_to_log = self.tasks[:5] - if self.task_names: - for idx in tasks_to_log: - if self.stats[idx]: - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_mean", self.stats[idx]['mean_r'], step) - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_var", self.stats[idx]['var_r'], step) - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_mean", self.stats[idx]['mean_l'], step) - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_var", self.stats[idx]['var_l'], step) - else: - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_mean", 0, step) - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_return_var", 0, step) - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_mean", 0, step) - writer.add_scalar(f"stats_per_task/task_{self.task_space.task_name(idx)}_episode_length_var", 0, step) - else: - for idx in tasks_to_log: - if self.stats[idx]: - writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", self.stats[idx]['mean_r'], step) - 
writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", self.stats[idx]['var_r'], step) - writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", self.stats[idx]['mean_l'], step) - writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", self.stats[idx]['var_l'], step) - else: - writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", 0, step) - writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", 0, step) - writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", 0, step) - writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", 0, step) + for idx in tasks_to_log: + if self.stats[idx]: + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_mean", self.stats[idx]['mean_r'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_return_var", self.stats[idx]['var_r'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_mean", self.stats[idx]['mean_l'], step) + writer.add_scalar(f"stats_per_task/task_{idx}_episode_length_var", self.stats[idx]['var_l'], step) except ImportError: warnings.warn("Wandb is not installed. Skipping logging.") except wandb.errors.Error: # No need to crash over logging :) warnings.warn("Failed to log curriculum stats to wandb.") - def output_results(self, output_path): + def save_statistics(self, output_path): """ Write task-specific statistics to file. """ From a6de4f97dfa3bcc98bcb02093e5d38c7f3be9365 Mon Sep 17 00:00:00 2001 From: allisony9 Date: Fri, 19 Apr 2024 17:31:03 -0400 Subject: [PATCH 08/14] remove keep_past_n --- syllabus/core/stat_recorder.py | 37 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py index 26b7a86c..b25d7c34 100644 --- a/syllabus/core/stat_recorder.py +++ b/syllabus/core/stat_recorder.py @@ -12,23 +12,25 @@ class StatRecorder: Individual stat tracking for each task. 
""" - def __init__(self, task_space: TaskSpace, keep_last_N=10, calc_past_N=None): + def __init__(self, task_space: TaskSpace, calc_past_n=None): """Initialize the StatRecorder""" self.task_space = task_space - self.calc_past_N = calc_past_N + self.calc_past_n = calc_past_n assert isinstance(self.task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead." assert isinstance(self.task_space.gym_space, Discrete), f"Only Discrete task spaces are supported. Got {type(task_space.gym_space)}" - if calc_past_N and calc_past_N > keep_last_N: - warnings.warn("The number of data points requested to calculate statistics exceeds the number of data points kept in memory. Will only use data points available in memory to calculate statistics instead.") self.tasks = self.task_space.get_tasks() self.num_tasks = self.task_space.num_tasks - self.episode_returns = {task: deque(maxlen=keep_last_N) for task in self.tasks} - self.episode_lengths = {task: deque(maxlen=keep_last_N) for task in self.tasks} - self.env_ids = {task: deque(maxlen=keep_last_N) for task in self.tasks} + if self.calc_past_n is not None: + self.episode_returns = {task: deque(maxlen=calc_past_n) for task in self.tasks} + self.episode_lengths = {task: deque(maxlen=calc_past_n) for task in self.tasks} + self.env_ids = {task: deque(maxlen=calc_past_n) for task in self.tasks} + else: + self.num_past_episodes = {task: 0 for task in self.tasks} + self.stats = {task: defaultdict(float) for task in self.tasks} def record(self, episode_return: float, episode_length: int, episode_task, env_id=None): @@ -41,25 +43,26 @@ def record(self, episode_return: float, episode_length: int, episode_task, env_i """ if episode_task in self.tasks: - if self.calc_past_N: + if self.calc_past_n is not None: self.episode_returns[episode_task].append(episode_return) self.episode_lengths[episode_task].append(episode_length) self.env_ids[episode_task].append(env_id) - self.stats[episode_task]['mean_r'] = 
np.mean(list(self.episode_returns[episode_task])[-self.calc_past_N:]) # I am not sure whether there is a more efficient way to slice to deque. I temperorily convert it to a list then slice it, which should cost O(n) - self.stats[episode_task]['var_r'] = np.var(list(self.episode_returns[episode_task])[-self.calc_past_N:]) - self.stats[episode_task]['mean_l'] = np.mean(list(self.episode_lengths[episode_task])[-self.calc_past_N:]) - self.stats[episode_task]['var_l'] = np.var(list(self.episode_lengths[episode_task])[-self.calc_past_N:]) + self.stats[episode_task]['mean_r'] = np.mean(self.episode_returns[episode_task]) + self.stats[episode_task]['var_r'] = np.var(self.episode_returns[episode_task]) + self.stats[episode_task]['mean_l'] = np.mean(self.episode_lengths[episode_task]) + self.stats[episode_task]['var_l'] = np.var(self.episode_lengths[episode_task]) else: # save the mean/variance of all the episodes - N_past = len(self.episode_returns[episode_task]) + n_past = self.num_past_episodes[episode_task] + self.num_past_episodes[episode_task] += 1 - self.stats[episode_task]['mean_r'] = (self.stats[episode_task]['mean_r'] * N_past + episode_return) / (N_past + 1) - self.stats[episode_task]['mean_r_squared'] = (self.stats[episode_task]['mean_r_squared'] * N_past + episode_return ** 2) / (N_past + 1) + self.stats[episode_task]['mean_r'] = (self.stats[episode_task]['mean_r'] * n_past + episode_return) / (n_past + 1) + self.stats[episode_task]['mean_r_squared'] = (self.stats[episode_task]['mean_r_squared'] * n_past + episode_return ** 2) / (n_past + 1) self.stats[episode_task]['var_r'] = self.stats[episode_task]['mean_r_squared'] - self.stats[episode_task]['mean_r'] ** 2 - self.stats[episode_task]['mean_l'] = (self.stats[episode_task]['mean_l'] * N_past + episode_length) / (N_past + 1) - self.stats[episode_task]['mean_l_squared'] = (self.stats[episode_task]['mean_l_squared'] * N_past + episode_length ** 2) / (N_past + 1) + self.stats[episode_task]['mean_l'] = 
(self.stats[episode_task]['mean_l'] * n_past + episode_length) / (n_past + 1) + self.stats[episode_task]['mean_l_squared'] = (self.stats[episode_task]['mean_l_squared'] * n_past + episode_length ** 2) / (n_past + 1) self.stats[episode_task]['var_l'] = self.stats[episode_task]['mean_l_squared'] - self.stats[episode_task]['mean_l'] ** 2 else: raise ValueError("Unknown task") From 1e6938d4a62e0042e234229a84c7fa7dadd625aa Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 23 Apr 2024 15:34:43 -0400 Subject: [PATCH 09/14] implemented the normalize function of stat_recorder, write test cases for the StatRecorder class --- syllabus/core/__init__.py | 1 + syllabus/core/stat_recorder.py | 20 +++++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/syllabus/core/__init__.py b/syllabus/core/__init__.py index 07322779..08f7a709 100644 --- a/syllabus/core/__init__.py +++ b/syllabus/core/__init__.py @@ -13,3 +13,4 @@ from .environment_sync_wrapper import MultiProcessingSyncWrapper, RaySyncWrapper # , PettingZooMultiProcessingSyncWrapper from .multivariate_curriculum_wrapper import MultitaskWrapper +from .stat_recorder import StatRecorder diff --git a/syllabus/core/stat_recorder.py b/syllabus/core/stat_recorder.py index b25d7c34..d4c025b9 100644 --- a/syllabus/core/stat_recorder.py +++ b/syllabus/core/stat_recorder.py @@ -3,13 +3,12 @@ import warnings import numpy as np from syllabus.task_space import TaskSpace -from typing import Callable -from gymnasium.spaces import Discrete #, MultiDiscrete? +from gymnasium.spaces import Discrete from collections import deque, defaultdict class StatRecorder: """ - Individual stat tracking for each task. + Individual statistics tracking for each task. """ def __init__(self, task_space: TaskSpace, calc_past_n=None): @@ -37,9 +36,9 @@ def record(self, episode_return: float, episode_length: int, episode_task, env_i """ Record the length and return of an episode for a given task. 
- :param episode_task: Identifier for the task :param episode_length: Length of the episode, i.e. the total number of steps taken during the episode :param episodic_return: Total return for the episode + :param episode_task: Identifier for the task """ if episode_task in self.tasks: @@ -53,7 +52,6 @@ def record(self, episode_return: float, episode_length: int, episode_task, env_i self.stats[episode_task]['mean_l'] = np.mean(self.episode_lengths[episode_task]) self.stats[episode_task]['var_l'] = np.var(self.episode_lengths[episode_task]) else: - # save the mean/variance of all the episodes n_past = self.num_past_episodes[episode_task] self.num_past_episodes[episode_task] += 1 @@ -90,6 +88,18 @@ def log_metrics(self, writer, step=None, log_full_dist=False): # No need to crash over logging :) warnings.warn("Failed to log curriculum stats to wandb.") + def normalize(self, reward, task): + """ + Normalize reward by task. + """ + task_stats = self.stats[task] + reward_mean = task_stats['mean_r'] + reward_std = np.sqrt(task_stats['var_r']) + normalized_reward = deque(maxlen=reward.maxlen) + for r in reward: + normalized_reward.append((r - reward_mean) / max(0.01, reward_std)) + return normalized_reward + def save_statistics(self, output_path): """ Write task-specific statistics to file. 
From c29887e82ec335e4a02c1feab37c97060cecf923 Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 23 Apr 2024 15:48:50 -0400 Subject: [PATCH 10/14] syllabus/tests/__init__ --- tests/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..f600ce6b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,4 @@ +from .sync_test_curriculum import SyncTestCurriculum +from .sync_test_env import SyncTestEnv +from .utils import * +from .determinism import test_determinism From a00f90b58b8697a574f5daa5ee7fc9eecd8de15a Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 23 Apr 2024 15:53:26 -0400 Subject: [PATCH 11/14] test cases for the StatRecorder class --- tests/test_stat_recorder.py | 89 +++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 tests/test_stat_recorder.py diff --git a/tests/test_stat_recorder.py b/tests/test_stat_recorder.py new file mode 100644 index 00000000..89144e46 --- /dev/null +++ b/tests/test_stat_recorder.py @@ -0,0 +1,89 @@ +import gymnasium as gym +import numpy as np +from collections import deque +from syllabus.task_space import TaskSpace +from syllabus.core import StatRecorder + +def test(stat_recorder, simulated_eps_info, expected_episode_reward_mean_by_task, expected_episode_reward_std_by_task, expected_episode_length_mean_by_task, expected_episode_length_std_by_task, expected_normalized_reward): + num_passed = 0 + num_failed = 0 + for i in range(len(simulated_eps_info)): + task_id, eps_reward, eps_length = simulated_eps_info[i] + stat_recorder.record(eps_reward, eps_length, task_id) + rewards_for_task = deque([x[1] for x in simulated_eps_info[:i+1] if x[0]==task_id], maxlen=10) + normalized_reward_for_task = stat_recorder.normalize(rewards_for_task, task_id) + try: + assert abs(stat_recorder.stats[task_id]['mean_r'] - expected_episode_reward_mean_by_task[i]) < 1e-7 + assert abs(np.sqrt(stat_recorder.stats[task_id]['var_r']) - 
expected_episode_reward_std_by_task[i]) < 1e-7 + assert abs(stat_recorder.stats[task_id]['mean_l'] - expected_episode_length_mean_by_task[i]) < 1e-7 + assert abs(np.sqrt(stat_recorder.stats[task_id]['var_l']) - expected_episode_length_std_by_task[i]) < 1e-7 + for j in range(len(normalized_reward_for_task)): + assert abs(normalized_reward_for_task[j] - expected_normalized_reward[i][j]) < 1e-7 + print(f"Test case {i} PASSED.") + num_passed += 1 + except AssertionError: + print(f"Test case {i} FAILED.") + num_failed += 1 + print(f"{len(simulated_eps_info)} tests total, {num_passed} tests passed, {num_failed} test failed. Pass rate: {num_passed / len(simulated_eps_info) * 100}%.\n") + +def main(): + """ + simulated_eps_info: A list of tuples simulating episodic information. + Each tuple satisfies the format (task_id, episode_return, episode_length). + """ + simulated_eps_info = [(0, 5.0, 48), (1, 1.0, 75), (2, 2.0, 36), (2, 4.0, 65), (0, 1.0, 54), (1, 3.0, 82), + (0, 2.0, 39), (2, 3.0, 80), (0, 4.0, 57), (1, 0.0, 94), (1, 2.0, 64), (0, 0.0, 45), + (2, 1.0, 86), (0, 2.0, 68), (1, 2.0, 92), (2, 1.0, 71), (0, 3.0, 32), (2, 1.0, 47)] + task_space = TaskSpace(gym.spaces.Discrete(3), list(np.arange(0, 3))) + + """ + Testing StatRecorder by calculating running average.
+ """ + print("Testing StatRecorder by calculating running average: ") + expected_episode_reward_mean_by_task = [5.0, 1.0, 2.0, 3.0, 3.0, 2.0, + 2.6666667, 3.0, 3.0, 1.3333333, 1.5, 2.4, + 2.5, 2.3333333, 1.6, 2.2, 2.4285714, 2.0 + ] + expected_episode_reward_std_by_task = [0.0, 0.0, 0.0, 1.0, 2.0, 1.0, + 1.6996732, 0.8164966, 1.5811388, 1.2472191, 1.118034, 1.8547237, + 1.118034, 1.6996732, 1.0198039, 1.1661904, 1.5907898, 1.1547005] + expected_episode_length_mean_by_task = [48, 75, 36, 50.5, 51, 78.5, + 47, 60.3333333, 49.5, 83.6666667, 78.75, 48.6, + 66.75, 51.8333333, 81.4, 67.6, 49, 64.1666667] + expected_episode_length_std_by_task = [0.0, 0.0, 0.0, 14.5, 3.0, 3.5, + 6.164414, 18.2635034, 6.8738635, 7.8457349, 10.8943793, 6.406247, 19.3309984, 9.2990442, 11.0923397, 17.3735431, 11.0582871, 17.620222] + expected_normalized_reward = [deque([0.0]), deque([0.0]), deque([0.0]), deque([-1.0, 1.0]), deque([1.0, -1.0]), deque([-1.0, 1.0]), + deque([1.3728129, -0.9805807, -0.3922323]), deque([-1.2247449, 1.2247449, 0.0]), deque([1.2649111, -1.2649111, -0.6324555, 0.6324555]), deque([-0.2672612, 1.3363062, -1.069045]), deque([-0.4472136, 1.3416408, -1.3416408, 0.4472136]), deque([1.4018261, -0.7548294, -0.2156655, 0.8626622, -1.2939933]), + deque([-0.4472136, 1.3416408, 0.4472136, -1.3416408]), deque([1.5689291, -0.7844645, -0.1961161, 0.9805807, -1.3728129, -0.1961161]), deque([-0.5883484, 1.3728129, -1.5689291, 0.3922323, 0.3922323]), deque([-0.1714986, 1.5434873, 0.6859943, -1.0289915, -1.0289915]), deque([1.6164477, -0.8980265, -0.269408, 0.9878292, -1.5266451, -0.269408, 0.3592106]), deque([0.0, 1.7320508, 0.8660254, -0.8660254, -0.8660254, -0.8660254]) + ] + + stat_recorder = StatRecorder(task_space) + test(stat_recorder, simulated_eps_info, expected_episode_reward_mean_by_task, expected_episode_reward_std_by_task, expected_episode_length_mean_by_task, expected_episode_length_std_by_task, expected_normalized_reward) + + """ + Testing StatRecorder by calculating 
based on last n episodes. + """ + calc_past_n = 3 + print(f"Testing StatRecorder by calculating based on last {calc_past_n} episodes: ") + expected_episode_reward_mean_by_task = [5.0, 1.0, 2.0, 3.0, 3.0, 2.0, + 2.6666667, 3.0, 2.3333333, 1.3333333, 1.6666667, 2.0, + 2.6666667, 2.0, 1.3333333, 1.6666667, 1.6666667, 1.0] + expected_episode_reward_std_by_task = [0.0, 0.0, 0.0, 1.0, 2.0, 1.0, + 1.6996732, 0.8164966, 1.2472191, 1.2472191, 1.2472191, 1.6329932, + 1.2472191, 1.6329932, 0.942809, 0.942809, 1.2472191, 0.0] + expected_episode_length_mean_by_task = [48, 75, 36, 50.5, 51, 78.5, + 47, 60.3333333, 50, 83.6666667, 80, 47, + 77, 56.6666667, 83.3333333, 79, 48.3333333, 68] + expected_episode_length_std_by_task = [0.0, 0.0, 0.0, 14.5, 3.0, 3.5, + 6.164414, 18.2635034, 7.8740079, 7.8457349, 12.328828, 7.4833148, + 8.8317609, 9.3926685, 13.6950924, 6.164414, 14.8847424, 16.0623784] + expected_normalized_reward = [deque([0.0]), deque([0.0]), deque([0.0]), deque([-1.0, 1.0]), deque([1.0, -1.0]), deque([-1.0, 1.0]), + deque([1.3728129, -0.9805807, -0.3922323]), deque([-1.2247449, 1.2247449, 0.0]), deque([2.1380899, -1.069045, -0.2672612, 1.3363062]), deque([-0.2672612, 1.3363062, -1.069045]), deque([-0.5345225, 1.069045, -1.3363062, 0.2672612]), deque([1.8371173, -0.6123724, 0.0, 1.2247449, -1.2247449]), + deque([-0.5345225, 1.069045, 0.2672612, -1.3363062]), deque([1.8371173, -0.6123724, 0.0, 1.2247449, -1.2247449, 0.0]), deque([-0.3535534, 1.767767, -1.4142136, 0.7071068, 0.7071068]), deque([0.3535534, 2.4748737, 1.4142136, -0.7071068, -0.7071068]), deque([2.6726124, -0.5345225, 0.2672612, 1.8708287, -1.3363062, 0.2672612, 1.069045]), deque([100.0, 300.0, 200.0, 0.0, 0.0, 0.0]) + ] + + stat_recorder = StatRecorder(task_space, calc_past_n) + test(stat_recorder, simulated_eps_info, expected_episode_reward_mean_by_task, expected_episode_reward_std_by_task, expected_episode_length_mean_by_task, expected_episode_length_std_by_task, expected_normalized_reward) + +if __name__ 
== "__main__": + main() \ No newline at end of file From 4441ba3a597cf4137dae2043564038d7053cd0cf Mon Sep 17 00:00:00 2001 From: allisony9 Date: Tue, 23 Apr 2024 23:32:57 -0400 Subject: [PATCH 12/14] merge in upstream --- core/__init__.py | 15 + core/curriculum_base.py | 223 +++ core/curriculum_sync_wrapper.py | 323 +++++ core/environment_sync_wrapper.py | 348 +++++ core/multivariate_curriculum_wrapper.py | 82 ++ core/task_interface/__init__.py | 5 + core/task_interface/environment_task_env.py | 89 ++ core/task_interface/reinit_task_wrapper.py | 119 ++ core/task_interface/subclass_task_wrapper.py | 102 ++ core/task_interface/task_wrapper.py | 99 ++ core/utils.py | 32 + curricula/__init__.py | 11 + curricula/annealing_box.py | 56 + curricula/domain_randomization.py | 21 + curricula/learning_progress.py | 198 +++ curricula/noop.py | 62 + curricula/plr/__init__.py | 3 + curricula/plr/central_plr_wrapper.py | 239 +++ curricula/plr/plr_wrapper.py | 295 ++++ curricula/plr/task_sampler.py | 354 +++++ curricula/sequential.py | 207 +++ curricula/simple_box.py | 63 + examples/__init__.py | 0 examples/experimental/README | 1 + examples/experimental/cleanrl_cartpole.py | 326 +++++ examples/experimental/cleanrl_minigrid_plr.py | 353 +++++ examples/experimental/cleanrl_minihack_plr.py | 358 +++++ .../cleanrl_pettingzoo_pistonball_plr.py | 334 +++++ examples/experimental/dormant_neurons.py | 764 ++++++++++ examples/experimental/rllib_cartpole.py | 40 + examples/experimental/sb3_procgen_plr.py | 116 ++ examples/experimental/torchbeast_nethack.py | 1280 +++++++++++++++++ examples/task_wrappers/__init__.py | 23 + .../task_wrappers/cartpole_task_wrapper.py | 24 + .../task_wrappers/minigrid_task_wrapper.py | 88 ++ .../task_wrappers/minihack_task_wrapper.py | 30 + examples/task_wrappers/nethack_wrappers.py | 493 +++++++ .../task_wrappers/pistonball_task_wrapper.py | 35 + .../task_wrappers/procgen_task_wrapper.py | 86 ++ .../cleanrl_procgen_centralplr.py | 536 +++++++ 
.../training_scripts/cleanrl_procgen_plr.py | 528 +++++++ examples/utils/__init__.py | 0 examples/utils/vecenv.py | 308 ++++ examples/utils/vtrace.py | 138 ++ task_space/__init__.py | 1 + task_space/task_space.py | 234 +++ task_space/test_task_space.py | 182 +++ tests/determinism.py | 154 ++ tests/sync_test_curriculum.py | 51 + tests/sync_test_env.py | 33 + tests/utils.py | 278 ++++ 51 files changed, 9740 insertions(+) create mode 100644 core/__init__.py create mode 100644 core/curriculum_base.py create mode 100644 core/curriculum_sync_wrapper.py create mode 100644 core/environment_sync_wrapper.py create mode 100644 core/multivariate_curriculum_wrapper.py create mode 100644 core/task_interface/__init__.py create mode 100644 core/task_interface/environment_task_env.py create mode 100644 core/task_interface/reinit_task_wrapper.py create mode 100644 core/task_interface/subclass_task_wrapper.py create mode 100644 core/task_interface/task_wrapper.py create mode 100644 core/utils.py create mode 100644 curricula/__init__.py create mode 100644 curricula/annealing_box.py create mode 100644 curricula/domain_randomization.py create mode 100644 curricula/learning_progress.py create mode 100644 curricula/noop.py create mode 100644 curricula/plr/__init__.py create mode 100644 curricula/plr/central_plr_wrapper.py create mode 100644 curricula/plr/plr_wrapper.py create mode 100644 curricula/plr/task_sampler.py create mode 100644 curricula/sequential.py create mode 100644 curricula/simple_box.py create mode 100644 examples/__init__.py create mode 100644 examples/experimental/README create mode 100644 examples/experimental/cleanrl_cartpole.py create mode 100644 examples/experimental/cleanrl_minigrid_plr.py create mode 100644 examples/experimental/cleanrl_minihack_plr.py create mode 100644 examples/experimental/cleanrl_pettingzoo_pistonball_plr.py create mode 100644 examples/experimental/dormant_neurons.py create mode 100644 examples/experimental/rllib_cartpole.py create mode 100644 
examples/experimental/sb3_procgen_plr.py create mode 100644 examples/experimental/torchbeast_nethack.py create mode 100644 examples/task_wrappers/__init__.py create mode 100644 examples/task_wrappers/cartpole_task_wrapper.py create mode 100644 examples/task_wrappers/minigrid_task_wrapper.py create mode 100644 examples/task_wrappers/minihack_task_wrapper.py create mode 100644 examples/task_wrappers/nethack_wrappers.py create mode 100644 examples/task_wrappers/pistonball_task_wrapper.py create mode 100644 examples/task_wrappers/procgen_task_wrapper.py create mode 100644 examples/training_scripts/cleanrl_procgen_centralplr.py create mode 100644 examples/training_scripts/cleanrl_procgen_plr.py create mode 100644 examples/utils/__init__.py create mode 100644 examples/utils/vecenv.py create mode 100644 examples/utils/vtrace.py create mode 100644 task_space/__init__.py create mode 100644 task_space/task_space.py create mode 100644 task_space/test_task_space.py create mode 100644 tests/determinism.py create mode 100644 tests/sync_test_curriculum.py create mode 100644 tests/sync_test_env.py create mode 100644 tests/utils.py diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 00000000..07322779 --- /dev/null +++ b/core/__init__.py @@ -0,0 +1,15 @@ +# Environment Code +from .task_interface import TaskWrapper, SubclassTaskWrapper, ReinitTaskWrapper, TaskEnv # , PettingZooTaskWrapper, PettingZooTaskEnv + +# Curriculum Code +from .utils import decorate_all_functions, UsageError, enumerate_axes +from .curriculum_base import Curriculum +from .curriculum_sync_wrapper import (CurriculumWrapper, + MultiProcessingCurriculumWrapper, + MultiProcessingComponents, + RayCurriculumWrapper, + make_multiprocessing_curriculum, + make_ray_curriculum) + +from .environment_sync_wrapper import MultiProcessingSyncWrapper, RaySyncWrapper # , PettingZooMultiProcessingSyncWrapper +from .multivariate_curriculum_wrapper import MultitaskWrapper diff --git a/core/curriculum_base.py 
class Curriculum:
    """Base class and API for defining curricula to interface with Gym environments.

    Subclasses supply the task distribution through :meth:`_sample_distribution` and
    must define the class attributes ``REQUIRES_STEP_UPDATES`` and
    ``REQUIRES_EPISODE_UPDATES``, which the properties of the same name read.
    """

    def __init__(self, task_space: "TaskSpace", random_start_tasks: int = 0, task_names: Callable = None) -> None:
        """Initialize the base Curriculum

        :param task_space: the environment's task space from which new tasks are sampled
        TODO: Implement this in a way that works with any curriculum, maybe as a wrapper
        :param random_start_tasks: Number of completed tasks during which the startup distribution
            is used instead of the algorithm's own distribution, defaults to 0
        TODO: Use task space for this
        :param task_names: Names of the tasks in the task space, defaults to None
        """
        assert isinstance(task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead."
        self.task_space = task_space
        self.random_start_tasks = random_start_tasks
        self.completed_tasks = 0
        self.task_names = task_names
        self.n_updates = 0

        if self.num_tasks == 0:
            warnings.warn("Task space is empty. This will cause errors during sampling if no tasks are added.")

    @property
    def requires_step_updates(self) -> bool:
        """Returns whether the curriculum requires step updates from the environment.

        :return: True if the curriculum requires step updates, False otherwise
        """
        return self.__class__.REQUIRES_STEP_UPDATES

    @property
    def requires_episode_updates(self) -> bool:
        """Returns whether the curriculum requires episode updates from the environment.

        :return: True if the curriculum requires episode updates, False otherwise
        """
        return self.__class__.REQUIRES_EPISODE_UPDATES

    @property
    def num_tasks(self) -> int:
        """Counts the number of tasks in the task space.

        :return: Returns the number of tasks in the task space if it is countable, TODO: -1 otherwise
        """
        return self.task_space.num_tasks

    @property
    def tasks(self) -> List[tuple]:
        """List all of the tasks in the task space.

        :return: List of tasks if task space is enumerable, TODO: empty list otherwise?
        """
        return list(self.task_space.tasks)

    def add_task(self, task: typing.Any) -> None:
        """Add a task to the curriculum. Not supported by the base class.

        :param task: Task to add
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        # TODO
        raise NotImplementedError("This curriculum does not support adding tasks after initialization.")

    def update_task_progress(self, task: typing.Any, progress: Tuple[float, bool], env_id: int = None) -> None:
        """Update the curriculum with a task and its progress.

        The base implementation only counts the call; that count is what ends the
        startup sampling phase (see :meth:`_should_use_startup_sampling`).

        :param task: Task for which progress is being updated.
        :param progress: Progress toward completion or success rate of the given task. 1.0 or True typically indicates a complete task.
        :param env_id: Identifier of the environment reporting the progress, defaults to None
        """
        self.completed_tasks += 1

    def update_on_step(self, task: typing.Any, obs: typing.Any, rew: float, term: bool, trunc: bool, info: dict, env_id: int = None) -> None:
        """ Update the curriculum with the current step results from the environment.

        :param task: Task the step was taken in
        :param obs: Observation from the environment
        :param rew: Reward from the environment
        :param term: True if the episode ended on this step, False otherwise
        :param trunc: True if the episode was truncated on this step, False otherwise
        :param info: Extra information from the environment
        :param env_id: Identifier of the environment reporting the step, defaults to None
        :raises NotImplementedError:
        """
        raise NotImplementedError("This curriculum does not require step updates. Set update_on_step for the environment sync wrapper to False to improve performance and prevent this error.")

    def update_on_step_batch(self, step_results: List[typing.Tuple[Any, Any, int, int, int, int]], env_id: int = None) -> None:
        """Update the curriculum with a batch of step results from the environment.

        This method can be overridden to provide a more efficient implementation. It is used
        as a convenience function and to optimize the multiprocessing message passing throughput.

        :param step_results: Six parallel sequences, in order: tasks, observations, rewards,
            terminations, truncations and infos
        :param env_id: Identifier of the environment reporting the steps, defaults to None
        """
        tasks, obs, rews, terms, truncs, infos = tuple(step_results)
        for step in zip(tasks, obs, rews, terms, truncs, infos):
            self.update_on_step(*step, env_id=env_id)

    def update_on_episode(self, episode_return: float, episode_length: int, episode_task: Any, env_id: int = None) -> None:
        """Update the curriculum with episode results from the environment.

        The base implementation does nothing.

        :param episode_return: Episodic return
        :param episode_length: Number of steps in the episode
        :param episode_task: Task the episode was played on
        :param env_id: Identifier of the environment reporting the episode, defaults to None
        """
        # TODO: Add update_on_episode option similar to update-on_step
        pass

    def update_on_demand(self, metrics: typing.Dict):
        """Update the curriculum with arbitrary inputs.

        :param metrics: Arbitrary dictionary of information. Can be used to provide gradient/error based
                        updates from the training process.
        :raises NotImplementedError:
        """
        raise NotImplementedError

    # TODO: Move to curriculum sync wrapper?
    def update(self, update_data: typing.Dict[str, tuple]):
        """Update the curriculum with the specified update type.
        TODO: Change method header to not use dictionary, use enums?

        :param update_data: Dictionary with an "update_type" key which maps to one of ["step", "step_batch",
            "episode", "on_demand", "task_progress", "task_progress_batch", "add_task", "noop"], a "metrics"
            key holding the arguments for that update type, and an optional "env_id" key.
        :raises NotImplementedError: if the update type is not recognized
        """
        update_type = update_data["update_type"]
        args = update_data["metrics"]
        env_id = update_data.get("env_id")

        if update_type == "step":
            self.update_on_step(*args, env_id=env_id)
        elif update_type == "step_batch":
            self.update_on_step_batch(*args, env_id=env_id)
        elif update_type == "episode":
            self.update_on_episode(*args, env_id=env_id)
        elif update_type == "on_demand":
            # Directly pass metrics without expanding
            self.update_on_demand(args)
        elif update_type == "task_progress":
            self.update_task_progress(*args, env_id=env_id)
        elif update_type == "task_progress_batch":
            tasks, progresses = args
            for task, progress in zip(tasks, progresses):
                self.update_task_progress(task, progress, env_id=env_id)
        elif update_type == "add_task":
            self.add_task(args)
        elif update_type == "noop":
            # Used to request tasks from the synchronization layer
            pass
        else:
            raise NotImplementedError(f"Update type {update_type} not implemented.")
        # Only reached for recognized update types.
        self.n_updates += 1

    def update_batch(self, update_data: List[typing.Dict]):
        """Update the curriculum with batch of updates.

        :param update_data: List of updates of potentially varying types
        """
        for update in update_data:
            self.update(update)

    def _sample_distribution(self) -> List[float]:
        """Returns a sample distribution over the task space.

        Any curriculum that maintains a true probability distribution should implement this method to retrieve it.
        """
        raise NotImplementedError

    def _should_use_startup_sampling(self) -> bool:
        """Return True while fewer than ``random_start_tasks`` tasks have been completed."""
        return self.random_start_tasks > 0 and self.completed_tasks < self.random_start_tasks

    def _startup_sample(self) -> List:
        """Return the distribution used during the startup phase.

        NOTE(review): all probability mass is placed on the first task, although the
        ``random_start_tasks`` parameter is described as uniform sampling - confirm intent.
        """
        task_dist = [0.0 for _ in range(self.num_tasks)]
        task_dist[0] = 1.0
        return task_dist

    def sample(self, k: int = 1) -> Union[List, Any]:
        """Sample k tasks from the curriculum.

        :param k: Number of tasks to sample, defaults to 1
        :return: Array of k task indices
        """
        assert self.num_tasks > 0, "Task space is empty. Please add tasks to the curriculum before sampling."

        # _startup_sample() yields a distribution, not tasks, so it has to go
        # through the same index sampling as the regular distribution.
        if self._should_use_startup_sampling():
            task_dist = self._startup_sample()
        else:
            task_dist = self._sample_distribution()

        # Sample indices because np.choice does not play nice with tuple tasks
        task_idx = np.random.choice(self.num_tasks, size=k, p=task_dist)
        return task_idx

    def log_metrics(self, writer, step=None, log_full_dist=False):
        """Log the task distribution to the provided tensorboard writer.

        Logging is skipped with a warning when wandb cannot be imported or raises.

        :param writer: Tensorboard summary writer.
        :param step: Global step to attach to the logged scalars, defaults to None
        :param log_full_dist: Log every task instead of only the first 10, defaults to False
        """
        try:
            import wandb
            task_dist = self._sample_distribution()
            if len(task_dist) > 10 and not log_full_dist:
                warnings.warn("Only logging stats for 10 tasks.")
                task_dist = task_dist[:10]
            if self.task_names:
                for idx, prob in enumerate(task_dist):
                    writer.add_scalar(f"curriculum/task_{self.task_space.task_name(idx)}_prob", prob, step)
            else:
                for idx, prob in enumerate(task_dist):
                    writer.add_scalar(f"curriculum/task_{idx}_prob", prob, step)
        except ImportError:
            warnings.warn("Wandb is not installed. Skipping logging.")
        except wandb.errors.Error:
            # No need to crash over logging :)
            warnings.warn("Failed to log curriculum stats to wandb.")
class MultiProcessingComponents:
    """Queues and shared counters used by the multiprocessing sync wrappers.

    Holds the task and update queues plus three single-slot ``ShareableList``
    counters (environment ids, queued tasks, queued updates). Every counter
    access is guarded by one lock. The task/update counters are only
    maintained by the put/get helpers while ``_debug`` is True.
    """

    def __init__(self, task_queue, update_queue):
        """
        :param task_queue: Queue that carries sampled tasks
        :param update_queue: Queue that carries curriculum updates
        """
        self.task_queue = task_queue
        self.update_queue = update_queue
        self._instance_lock = Lock()
        self._env_count = ShareableList([0])
        self._task_count = ShareableList([0])
        self._update_count = ShareableList([0])
        self._debug = False
        self._verbose = False

    def _adjust(self, counter, delta):
        """Add ``delta`` to a shared counter under the lock and return its new value."""
        with self._instance_lock:
            counter[0] += delta
            return counter[0]

    def get_id(self):
        """Return the next unused instance id; ids start at 0."""
        # _adjust reports the value after the increment, the id is the value before it.
        return self._adjust(self._env_count, 1) - 1

    def put_task(self, task):
        """Put a task on the task queue."""
        self.task_queue.put(task)
        if self._debug:
            task_count = self.added_task()
            if self._verbose:
                print(f"Task added to queue. Task count: {task_count}")

    def get_task(self):
        """Take the next task from the task queue and return it."""
        task = self.task_queue.get()
        if self._debug:
            task_count = self.removed_task()
            if self._verbose:
                print(f"Task removed from queue. Task count: {task_count}")
        return task

    def put_update(self, update):
        """Put an update on the update queue."""
        self.update_queue.put(update)
        if self._debug:
            update_count = self.added_update()
            if self._verbose:
                print(f"Update added to queue. Update count: {update_count}")

    def get_update(self):
        """Take the next update from the update queue and return it."""
        update = self.update_queue.get()
        if self._debug:
            update_count = self.removed_update()
            if self._verbose:
                print(f"Update removed from queue. Update count: {update_count}")

        return update

    def added_task(self):
        """Increment the task counter and return the new count."""
        return self._adjust(self._task_count, 1)

    def removed_task(self):
        """Decrement the task counter and return the new count."""
        return self._adjust(self._task_count, -1)

    def get_task_count(self):
        """Return the current task counter."""
        return self._adjust(self._task_count, 0)

    def get_update_count(self):
        """Return the current update counter."""
        return self._adjust(self._update_count, 0)

    def added_update(self):
        """Increment the update counter and return the new count."""
        return self._adjust(self._update_count, 1)

    def removed_update(self):
        """Decrement the update counter and return the new count."""
        return self._adjust(self._update_count, -1)
+ """ + # Process final few updates + start = time.time() + end = time.time() + while end - start < 3 and not self.update_queue.empty(): + time.sleep(0.5) + end = time.time() + + self.should_update = False + components = self.get_components() + components._env_count.shm.close() + components._env_count.shm.unlink() + components._update_count.shm.close() + components._update_count.shm.unlink() + components._task_count.shm.close() + components._task_count.shm.unlink() + # components.task_queue.close() + # components.update_queue.close() + + def _update_queues(self): + """ + Continuously process completed tasks and sample new tasks. + """ + # TODO: Refactor long method? Write tests first + # Update curriculum with environment results: + while self.should_update: + requested_tasks = 0 + while not self.update_queue.empty(): + batch_updates = self.get_components().get_update() # Blocks until update is available + + if isinstance(batch_updates, dict): + batch_updates = [batch_updates] + + # Count number of requested tasks + for update in batch_updates: + if "request_sample" in update and update["request_sample"]: + requested_tasks += 1 + + self.update_batch(batch_updates) + + # Sample new tasks + if requested_tasks > 0: + new_tasks = self.curriculum.sample(k=requested_tasks) + for i, task in enumerate(new_tasks): + message = { + "next_task": task, + "sample_id": self.num_assigned_tasks + i, + } + + self.get_components().put_task(message) + self.num_assigned_tasks += requested_tasks + time.sleep(0) + else: + time.sleep(0.01) + + def log_metrics(self, writer, step=None): + super().log_metrics(writer, step=step) + if self.get_components()._debug: + writer.add_scalar("curriculum/updates_in_queue", self.get_components()._update_count[0], step) + writer.add_scalar("curriculum/tasks_in_queue", self.get_components()._task_count[0], step) + + def add_task(self, task): + super().add_task(task) + self.added_tasks.append(task) + + def get_components(self): + return self._components + 
def remote_call(func):
    """
    Decorator for automatically forwarding calls to the curriculum via ray remote calls.

    Note that this causes functions to block, and should be only used for operations that do not require parallelization.

    A call is forwarded only when the decorated function is the one inherited from
    CurriculumWrapper. A function the subclass defines itself is executed as written.
    """
    @wraps(func)
    def wrapper(self, *args, **kw):
        f_name = func.__name__
        parent_func = getattr(CurriculumWrapper, f_name, None)

        # Only forward call if subclass does not explicitly override the function.
        # Compare the plain functions: a bound method looked up on `self` never
        # equals the function stored on CurriculumWrapper.
        if func is parent_func:
            curriculum_func = getattr(self.curriculum, f_name)
            return ray.get(curriculum_func.remote(*args, **kw))
        return func(self, *args, **kw)
    return wrapper


def make_multiprocessing_curriculum(curriculum, **kwargs):
    """
    Helper function for creating a MultiProcessingCurriculumWrapper.

    Creates the task and update queues, wraps the curriculum and starts the
    wrapper before returning it. Extra keyword arguments go to the wrapper.
    """
    task_queue = SimpleQueue()
    update_queue = SimpleQueue()

    mp_curriculum = MultiProcessingCurriculumWrapper(curriculum, task_queue, update_queue, **kwargs)
    mp_curriculum.start()
    return mp_curriculum


@ray.remote
class RayWrapper(CurriculumWrapper):
    """CurriculumWrapper exposed as a ray actor."""

    def __init__(self, curriculum: Curriculum) -> None:
        super().__init__(curriculum)
class MultiProcessingSyncWrapper(gym.Wrapper):
    """
    This wrapper is used to set the task on reset for a Gym environments running
    on parallel processes created using multiprocessing.Process. Meant to be used
    with a QueueLearningProgressCurriculum running on the main process.

    Tasks are read from, and updates written to, the queues held by the
    MultiProcessingComponents object passed to the constructor.
    """

    def __init__(self,
                 env,
                 components: MultiProcessingComponents,
                 update_on_step: bool = False,   # TODO: Fine grained control over which step elements are used. Controlled by curriculum?
                 update_on_progress: bool = False,   # TODO: Fine grained control over which step elements are used. Controlled by curriculum?
                 batch_size: int = 100,
                 buffer_size: int = 2,  # Having an extra task in the buffer minimizes wait time at reset
                 task_space: TaskSpace = None,
                 global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None):
        """
        :param env: Environment to wrap
        :param components: Shared queues and counters used to talk to the curriculum
        :param update_on_step: Buffer every step and send the buffer as step batches, defaults to False
        :param update_on_progress: Also send a task progress batch with each step batch, defaults to False
        :param batch_size: Number of steps buffered before a step batch is sent, defaults to 100
        :param buffer_size: Number of tasks requested up front, must be greater than 0, defaults to 2
        :param task_space: Task space used to encode and decode tasks, must be a TaskSpace
        :param global_task_completion: Stored on the wrapper; not used by the code in this class
        """
        # TODO: reimplement global task progress metrics
        assert isinstance(task_space, TaskSpace), f"task_space must be a TaskSpace object. Got {type(task_space)} instead."
        super().__init__(env)
        self.env = env
        self.components = components
        self._latest_task = None
        self.task_queue = components.task_queue
        self.update_queue = components.update_queue
        self.task_space = task_space
        self.update_on_step = update_on_step
        self.update_on_progress = update_on_progress
        self.batch_size = batch_size
        self.global_task_completion = global_task_completion
        self.task_progress = 0.0
        self._batch_step = 0
        self.instance_id = components.get_id()

        self.episode_length = 0
        self.episode_return = 0

        # Create batch buffers for step updates
        if self.update_on_step:
            self._obs = [None] * self.batch_size
            self._rews = np.zeros(self.batch_size, dtype=np.float32)
            self._terms = np.zeros(self.batch_size, dtype=bool)
            self._truncs = np.zeros(self.batch_size, dtype=bool)
            self._infos = [None] * self.batch_size
            self._tasks = [None] * self.batch_size
            self._task_progresses = [None] * self.batch_size

        # Request initial task
        assert buffer_size > 0, "Buffer size must be greater than 0 to sample initial task for envs."
        for _ in range(buffer_size):
            update = {
                "update_type": "noop",
                "metrics": None,
                "request_sample": True,
            }
            self.components.put_update(update)

    def reset(self, *args, **kwargs):
        """Take the next task from the task queue and reset the environment with it.

        :return: Whatever the wrapped environment's reset returns
        """
        self.step_updates = []
        self.task_progress = 0.0
        self.episode_length = 0
        self.episode_return = 0

        message = self.components.get_task()  # Blocks until a task is available
        next_task = self.task_space.decode(message["next_task"])
        self._latest_task = next_task

        # Add any new tasks
        if "added_tasks" in message:
            added_tasks = message["added_tasks"]
            for add_task in added_tasks:
                self.env.add_task(add_task)
        return self.env.reset(*args, new_task=next_task, **kwargs)

    def step(self, action):
        """Step the environment and queue curriculum updates.

        With update_on_step, each step is buffered and the buffer is sent when it is
        full or the episode ends. At the end of an episode a task progress update and
        an episode update are sent; the episode update requests the next task.

        :return: obs, rew, term, trunc, info from the wrapped environment
        """
        obs, rew, term, trunc, info = step_api_compatibility(self.env.step(action), output_truncation_bool=True)
        self.episode_length += 1
        self.episode_return += rew
        self.task_progress = info.get("task_completion", 0.0)

        # Update curriculum with step info
        if self.update_on_step:
            self._obs[self._batch_step] = obs
            self._rews[self._batch_step] = rew
            self._terms[self._batch_step] = term
            self._truncs[self._batch_step] = trunc
            self._infos[self._batch_step] = info
            self._tasks[self._batch_step] = self.task_space.encode(self.get_task())
            self._task_progresses[self._batch_step] = self.task_progress
            self._batch_step += 1

            # Send batched updates
            if self._batch_step >= self.batch_size or term or trunc:
                updates = self._package_step_updates()
                self.components.put_update(updates)
                self._batch_step = 0

        # Episode update
        if term or trunc:
            # Task progress
            task_update = {
                "update_type": "task_progress",
                "metrics": (self.task_space.encode(self.env.task), self.task_progress),
                "env_id": self.instance_id,
                "request_sample": False,
            }
            episode_update = {
                "update_type": "episode",
                "metrics": (self.episode_return, self.episode_length, self.task_space.encode(self.env.task)),
                "env_id": self.instance_id,
                "request_sample": True
            }
            self.components.put_update([task_update, episode_update])

        return obs, rew, term, trunc, info

    def _package_step_updates(self):
        """Build the update list for the steps buffered so far.

        :return: List holding a "step_batch" update and, with update_on_progress,
            a "task_progress_batch" update
        """
        filled = self._batch_step
        step_batch = {
            "update_type": "step_batch",
            "metrics": ([self._tasks[:filled], self._obs[:filled], self._rews[:filled], self._terms[:filled], self._truncs[:filled], self._infos[:filled]],),
            "env_id": self.instance_id,
            "request_sample": False
        }
        update = [step_batch]

        if self.update_on_progress:
            task_batch = {
                "update_type": "task_progress_batch",
                "metrics": (self._tasks[:filled], self._task_progresses[:filled],),
                "env_id": self.instance_id,
                "request_sample": False
            }
            update.append(task_batch)
        return update

    def add_task(self, task):
        """Send an "add_task" update for the given task."""
        update = {
            "update_type": "add_task",
            "metrics": task
        }
        # Go through the components so the debug update counter, which
        # get_update() decrements, also sees this put.
        self.components.put_update(update)

    def get_task(self):
        """Return the wrapped environment's task, or the last received task if it has none."""
        # Allow user to reject task
        if hasattr(self.env, "task"):
            return self.env.task
        return self._latest_task

    def __getattr__(self, attr):
        """Forward attributes not found on the wrapper to the wrapped environment.

        :raises AttributeError: if the wrapped environment does not have the attribute either
        """
        # Only called after normal lookup failed. If "env" itself is missing,
        # reading self.env below would recurse forever.
        if attr == "env":
            raise AttributeError(attr)
        return getattr(self.env, attr)
+# default_task=None, +# task_space: TaskSpace = None, +# global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None): +# super().__init__(env) +# self.env = env +# self.task_queue = task_queue +# self.update_queue = update_queue +# self.task_space = task_space +# self.update_on_step = update_on_step +# self.global_task_completion = global_task_completion +# self.task_completion = 0.0 +# self.warned_once = False +# self.step_results = [] +# if task_space.contains(default_task): +# self.default_task = default_task + +# # Request initial task +# update = { +# "update_type": "noop", +# "metrics": None, +# "request_sample": True +# } +# self.update_queue.put(update) + +# @property +# def agents(self): +# return self.env.agents + +# def reset(self, *args, **kwargs): +# self.step_results = [] + +# # Update curriculum +# update = { +# "update_type": "complete", +# "metrics": (self.task_space.encode(self.env.task), self.task_completion), +# "request_sample": True +# } +# self.update_queue.put(update) +# self.task_completion = 0.0 + +# # Sample new task +# if self.task_queue.empty(): +# # Choose default task if it is set, or keep the current task +# next_task = self.default_task if self.default_task is not None else self.task_space.sample() +# if not self.warned_once: +# print("\nTask queue was empty, selecting default task. 
class RaySyncWrapper(gym.Wrapper):
    """
    This wrapper is used to set the task on reset for a Gym environments running
    on parallel processes created using ray. Meant to be used with a
    RayLearningProgressCurriculum running on the main process.
    """
    def __init__(self,
                 env,
                 update_on_step: bool = True,
                 task_space: gym.Space = None,
                 global_task_completion: Callable[[Curriculum, np.ndarray, float, bool, Dict[str, Any]], bool] = None,
                 actor_name: str = "curriculum"):
        """
        :param env: Environment to wrap, must be a TaskWrapper or TaskEnv
        :param update_on_step: Send step batches to the curriculum, defaults to True
        :param task_space: Task space of the environment, defaults to None
        :param global_task_completion: Optional function that computes task completion from the
            curriculum handle and the step results, defaults to None
        :param actor_name: Name of the ray actor that hosts the curriculum, defaults to "curriculum"
        """
        # PettingZooTaskWrapper is not imported in this module, so it cannot be part of this check.
        assert isinstance(env, (TaskWrapper, TaskEnv)), "Env must implement the task API"
        super().__init__(env)
        self.env = env
        self.update_on_step = update_on_step    # Disable to improve performance 10x
        self.task_space = task_space
        self.curriculum = ray.get_actor(actor_name)
        self.task_completion = 0.0
        self.global_task_completion = global_task_completion
        self.step_results = []

    def reset(self, *args, **kwargs):
        """Report progress on the current task, sample a new task and reset the environment with it.

        :return: Whatever the wrapped environment's reset returns
        """
        self.step_results = []

        # Update curriculum
        update = {
            "update_type": "task_progress",
            "metrics": (self.env.task, self.task_completion),
            "request_sample": True
        }
        self.curriculum.update.remote(update)
        self.task_completion = 0.0

        # Sample new task
        sample = ray.get(self.curriculum.sample.remote())
        next_task = sample[0]

        return self.env.reset(*args, new_task=next_task, **kwargs)

    def step(self, action):
        """Step the environment, track task completion and send step batches to the curriculum.

        :return: obs, rew, term, trunc, info from the wrapped environment
        """
        obs, rew, term, trunc, info = self.env.step(action)

        if "task_completion" in info:
            if self.global_task_completion is not None:
                # TODO: Hide rllib interface?
                self.task_completion = self.global_task_completion(self.curriculum, obs, rew, term, trunc, info)
            else:
                self.task_completion = info["task_completion"]

        # TODO: Optimize
        if self.update_on_step:
            self.step_results.append((self.env.task, obs, rew, term, trunc, info))
            if len(self.step_results) >= 1000 or term or trunc:
                # The curriculum unpacks a step batch as six parallel sequences
                # (tasks, obs, rews, terms, truncs, infos), so transpose the
                # per-step tuples before sending.
                columns = [list(column) for column in zip(*self.step_results)]
                update = {
                    "update_type": "step_batch",
                    "metrics": (columns,),
                    "request_sample": False
                }
                self.curriculum.update.remote(update)
                self.step_results = []

        return obs, rew, term, trunc, info

    def change_task(self, new_task):
        """
        Changes the task of the existing environment to the new_task.

        Each environment will implement tasks differently. The easiest system would be to call a
        function or set an instance variable to change the task.

        Some environments may need to be reset or even reinitialized to change the task.
        If you need to reset or re-init the environment here, make sure to check
        that it is not in the middle of an episode to avoid unexpected behavior.
        """
        self.env.change_task(new_task)

    def add_task(self, task):
        """Ask the remote curriculum to add the given task."""
        self.curriculum.add_task.remote(task)

    def __getattr__(self, attr):
        """Forward attributes not found on the wrapper to the wrapped environment.

        Falsy values (0, False, empty containers) are returned as they are.

        :raises AttributeError: if the wrapped environment does not have the attribute either
        """
        # Only called after normal lookup failed. If "env" itself is missing,
        # reading self.env below would recurse forever.
        if attr == "env":
            raise AttributeError(attr)
        return getattr(self.env, attr)
+ """ + # Uniform distribution + if isinstance(self.task_space.gym_space, Tuple): + multivariate_dists = [self.curriculum._sample_distribution() for _ in self.task_space.gym_space.spaces] + elif isinstance(self.task_space.gym_space, Dict): + multivariate_dists = {name: self.curriculum._sample_distribution() for name in self.task_space.gym_space.keys()} + else: + raise NotImplementedError("Multivariate task space must be Tuple or Dict.") + return multivariate_dists + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + assert self.num_tasks > 0, "Task space is empty. Please add tasks to the curriculum before sampling." + + tasks = [] + for _ in range(k): + sample_dist = self._sample_distribution() + if isinstance(sample_dist, list): + task_components = [] + for dist in sample_dist: + task_components.append(self.curriculum.sample(k=1)[0]) + tasks.append(tuple(task_components)) + return tasks + + + multitask_dist = self._sample_distribution() + # TODO: Clean and comment + if isinstance(self.task_space.gym_space, Dict): + multitasks = [] + for _ in range(k): + multitask = {} + # TODO: Provide easier access to gym_space properties? 
class TaskEnv(gym.Env):
    """Gym environment base class with a task that can be changed on reset."""

    # TODO: Update to new TaskSpace API
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.task_completion = 0.0
        self.task_space = None
        self.task = None

    def reset(self, *args, **kwargs):
        """Reset the environment, switching task first when ``new_task`` is passed as a keyword.

        Returns the observation run through :meth:`observation`, and the info
        from the parent reset.
        """
        if "new_task" in kwargs:
            requested_task = kwargs.pop("new_task")
            self.change_task(requested_task)
            # TODO: Handle failure case for change task
            self.task = requested_task

        raw_obs, info = super().reset(*args, **kwargs)
        return self.observation(raw_obs), info

    def change_task(self, new_task):
        """
        Changes the task of the existing environment to the new_task.

        Each environment will implement tasks differently. The easiest system would be to call a
        function or set an instance variable to change the task.

        Some environments may need to be reset or even reinitialized to change the task.
        If you need to reset or re-init the environment here, make sure to check
        that it is not in the middle of an episode to avoid unexpected behavior.
        """
        raise NotImplementedError

    def add_task(self, task):
        """Adding tasks is not supported by the base class."""
        raise NotImplementedError("This environment does not support adding tasks.")

    def _task_completion(self, obs, rew, term, trunc, info) -> float:
        """
        Implement this function to indicate whether the selected task has been completed.
        This can be determined using the observation, rewards, term, trunc, info or internal values
        from the environment. Intended to be used for automatic curricula.
        Returns a boolean or float value indicating binary completion or scalar degree of completion.

        The default reports 1.0 once the episode has terminated or been truncated, 0.0 before that.
        """
        if term or trunc:
            return 1.0
        return 0.0

    def _encode_goal(self):
        """
        Implement this method to indicate which task is selected to the agent.
        Returns: Numpy array encoding the goal.
        """
        return None

    def observation(self, observation):
        """
        Adds the goal encoding to the observation.
        Override to add additional task-specific observations.
        Returns a modified observation.
        TODO: Complete this implementation and find way to support centralized encodings
        """
        encoded_goal = self._encode_goal()
        if encoded_goal is None:
            # No goal to attach, hand the observation back untouched
            return observation

        # Add goal to observation
        observation['goal'] = encoded_goal
        return observation

    def step(self, action):
        """
        Steps the environment with the given action.
        Unlike the typical Gym environment, this method should also add the
        {"task_completion": self.task_completion()} key to the info dictionary
        to support curricula that rely on this metric.
        """
        raise NotImplementedError
+ This is the identity function by default. + """ + return encoding + + def reset(self, new_task: Union[Tuple, int, float] = None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + """ + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info + + def change_task(self, new_task: Union[Tuple, int, float]): + """ + Change task by directly editing environment class. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + + # Update current task + if new_task not in self.task_envs: + self.task_envs[new_task] = self.env_fn(self.decode_task(new_task)) + + self.env = self.task_envs[new_task] + self.task = new_task + + def step(self, action): + """ + Step through environment and update task completion. 
+ """ + obs, rew, term, trunc, info = self.env.step(action) + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + return self.observation(obs), rew, term, trunc, info + + +if __name__ == "__main__": + from nle.env.tasks import (NetHackEat, NetHackGold, NetHackOracle, + NetHackScore, NetHackScout, NetHackStaircase, + NetHackStaircasePet) + + def run_episode(env, task: str = None, verbose=1): + env.reset(new_task=task) + task_name = type(env.unwrapped).__name__ + term = trunc = False + ep_rew = 0 + while not (term or trunc): + action = env.action_space.sample() + _, rew, term, trunc, _ = env.step(action) + ep_rew += rew + if verbose: + print(f"Episodic reward for {task_name}: {ep_rew}") + + print("Testing NethackTaskWrapper") + N_EPISODES = 100 + + # Initialize NLE + def create_env(task): + task_class = [NetHackScore, NetHackStaircase, NetHackStaircasePet, NetHackOracle, NetHackGold, NetHackEat, NetHackScout][task] + return task_class() + + nethack_env = NetHackScore() + nethack_task_env = ReinitTaskWrapper(nethack_env, create_env) + + start_time = time.time() + + for _ in range(N_EPISODES): + run_episode(nethack_task_env, verbose=0) + + end_time = time.time() + print(f"Run time same task: {end_time - start_time}") + start_time = time.time() + + for _ in range(N_EPISODES): + nethack_task = gym.spaces.Discrete(7).sample() + run_episode(nethack_task_env, task=nethack_task, verbose=0) + + end_time = time.time() + print(f"Run time swapping tasks: {end_time - start_time}") diff --git a/core/task_interface/subclass_task_wrapper.py b/core/task_interface/subclass_task_wrapper.py new file mode 100644 index 00000000..e558181c --- /dev/null +++ b/core/task_interface/subclass_task_wrapper.py @@ -0,0 +1,102 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. 
""" +import copy +from typing import List + +import gymnasium as gym +import numpy as np +from gymnasium import spaces +from syllabus.task_space import TaskSpace + +from .task_wrapper import TaskWrapper + + +class SubclassTaskWrapper(TaskWrapper): + # TODO: Automated tests + """ + This is a general wrapper for tasks defined as subclasses of a base environment. + + This wrapper reinitializes the environment with the provided env function at the start of each episode. + This is a simple, general solution to using Syllabus with tasks that need to be reinitialized, but it is inefficient. + It's likely that you can achieve better performance by using a more specialized wrapper. + """ + def __init__(self, env: gym.Env, task_subclasses: List[gym.Env] = None, **env_init_kwargs): + super().__init__(env) + + self.task_list = task_subclasses + self.task_space = TaskSpace(spaces.Discrete(len(self.task_list)), self.task_list) + self._env_init_kwargs = env_init_kwargs # kwargs for reinitializing the base environment + + # Add goal space to observation + self.observation_space = copy.deepcopy(self.env.observation_space) + self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list)) + + # Tracking episode end + self.done = True + + # Initialize all tasks + original_class = self.env.__class__ + for task in self.task_list: + self.env.__class__ = task + self.env.__init__(**self._env_init_kwargs) + + self.env.__class__ = original_class + self.env.__init__(**self._env_init_kwargs) + + @property + def current_task(self): + return self.env.__class__ + + def _task_name(self, task): + return self.task.__name__ + + def reset(self, new_task: int = None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. 
+ """ + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + self.done = False + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + # Ignore new task if mid episode + if self.current_task.__init__ != self._task_class(new_task).__init__ and not self.done: + raise RuntimeError("Cannot change task mid-episode.") + + # Ignore if task is unknown + if new_task >= len(self.task_list): + raise RuntimeError(f"Unknown task {new_task}.") + + # Update current task + prev_task = self.task + self.task = new_task + self.env.__class__ = self._task_class(new_task) + + # If task requires reinitialization + if type(self.env).__init__ != prev_task.__init__: + self.env.__init__(**self._env_init_kwargs) + + def _encode_goal(self): + goal_encoding = np.zeros(len(self.task_list)) + goal_encoding[self.task] = 1 + return goal_encoding + + def step(self, action): + """ + Step through environment and update task completion. 
+ """ + obs, rew, term, trunc, info = self.env.step(action) + self.done = term or trunc + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + return self.observation(obs), rew, term, trunc, info diff --git a/core/task_interface/task_wrapper.py b/core/task_interface/task_wrapper.py new file mode 100644 index 00000000..89c94ff3 --- /dev/null +++ b/core/task_interface/task_wrapper.py @@ -0,0 +1,99 @@ +import gymnasium as gym +# import pettingzoo +# from pettingzoo.utils.wrappers.base_parallel import BaseParallelWraper + + +class TaskWrapper(gym.Wrapper): + # TODO: Update to new TaskSpace API + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.task_completion = 0.0 + self.task_space = None + self.task = None # TODO: Would making this a property protect from accidental overriding? + + def reset(self, *args, **kwargs): + if "new_task" in kwargs: + new_task = kwargs.pop("new_task") + self.change_task(new_task) + # TODO: Handle failure case for change task + self.task = new_task + return self.observation(super().reset(*args, **kwargs)) + + def change_task(self, new_task): + """ + Changes the task of the existing environment to the new_task. + + Each environment will implement tasks differently. The easiest system would be to call a + function or set an instance variable to change the task. + + Some environments may need to be reset or even reinitialized to change the task. + If you need to reset or re-init the environment here, make sure to check + that it is not in the middle of an episode to avoid unexpected behavior. + """ + raise NotImplementedError + + def add_task(self, task): + raise NotImplementedError("This environment does not support adding tasks.") + + def _task_completion(self, obs, rew, term, trunc, info) -> float: + """ + Implement this function to indicate whether the selected task has been completed. 
+ This can be determined using the observation, rewards, term, trunc, info or internal values + from the environment. Intended to be used for automatic curricula. + Returns a boolean or float value indicating binary completion or scalar degree of completion. + """ + return 1.0 if term or trunc else 0.0 + + def _encode_goal(self): + """ + Implement this method to indicate which task is selected to the agent. + Returns: Numpy array encoding the goal. + """ + return None + + def observation(self, observation): + """ + Adds the goal encoding to the observation. + Override to add additional task-specific observations. + Returns a modified observation. + TODO: Complete this implementation and find way to support centralized encodings + """ + # Add goal to observation + goal_encoding = self._encode_goal() + if goal_encoding is not None: + observation['goal'] = goal_encoding + + return observation + + def step(self, action): + obs, rew, term, trunc, info = self.env.step(action) + + # Determine completion status of the current task + self.task_completion = self._task_completion(obs, rew, term, trunc, info) + info["task_completion"] = self.task_completion + + return self.observation(obs), rew, term, trunc, info + + def __getattr__(self, attr): + env_attr = self.env.__class__.__dict__.get(attr, None) + + if env_attr and callable(env_attr): + return env_attr + + +# class PettingZooTaskWrapper(TaskWrapper, BaseParallelWraper): +# def __init__(self, env: pettingzoo.ParallelEnv): +# super().__init__(env) +# self.task = None + +# @property +# def agents(self): +# return self.env.agents + +# def __getattr__(self, attr): +# env_attr = getattr(self.env, attr, None) +# if env_attr: +# return env_attr + +# def get_current_task(self): +# return self.current_task \ No newline at end of file diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 00000000..538b9df0 --- /dev/null +++ b/core/utils.py @@ -0,0 +1,32 @@ +from itertools import product +from typing import Union + 
+import numpy as np + + +def decorate_all_functions(function_decorator): + def decorator(cls): + for base_cls in cls.__bases__: + for name, obj in vars(base_cls).items(): + parent_func = getattr(base_cls, name) + child_func = getattr(cls, name) + + # Only apply decorator to functions not overridden by subclass. + if callable(obj) and child_func == parent_func: + setattr(cls, name, function_decorator(obj)) + return cls + return decorator + + +class UsageError(Exception): + ... + pass + + +def enumerate_axes(list_or_size: Union[np.ndarray, int]): + if isinstance(list_or_size, int) or isinstance(list_or_size, np.int64): + return tuple(range(list_or_size)) + elif isinstance(list_or_size, list) or isinstance(list_or_size, np.ndarray): + return tuple(product(*[enumerate_axes(x) for x in list_or_size])) + else: + raise NotImplementedError(f"{type(list_or_size)}") diff --git a/curricula/__init__.py b/curricula/__init__.py new file mode 100644 index 00000000..650a40e7 --- /dev/null +++ b/curricula/__init__.py @@ -0,0 +1,11 @@ +import sys + +from .domain_randomization import DomainRandomization +from .learning_progress import LearningProgressCurriculum +from .noop import NoopCurriculum +from .plr.central_plr_wrapper import CentralizedPrioritizedLevelReplay +from .plr.plr_wrapper import PrioritizedLevelReplay +from .plr.task_sampler import TaskSampler +from .sequential import SequentialCurriculum +from .simple_box import SimpleBoxCurriculum +from .annealing_box import AnnealingBoxCurriculum diff --git a/curricula/annealing_box.py b/curricula/annealing_box.py new file mode 100644 index 00000000..101981c7 --- /dev/null +++ b/curricula/annealing_box.py @@ -0,0 +1,56 @@ +import typing +from typing import Any, List, Union, Sequence, SupportsFloat, SupportsInt, Tuple +import numpy as np + +from gymnasium.spaces import Box +from syllabus.core import Curriculum + +class AnnealingBoxCurriculum(Curriculum): + REQUIRES_STEP_UPDATES = True + REQUIRES_EPISODE_UPDATES = False + 
REQUIRES_CENTRAL_UPDATES = False + + + def __init__( + self, + *curriculum_args, + start_values: List[SupportsFloat], + end_values: List[SupportsFloat], + total_steps: Tuple[int, List[int]], + **curriculum_kwargs, + ): + super().__init__(*curriculum_args, **curriculum_kwargs) + assert isinstance( + self.task_space.gym_space, Box + ), "AnnealingBoxCurriculum only supports Box task spaces." + + self.start_values = np.array(start_values, dtype=np.float32) + self.end_values = np.array(end_values, dtype=np.float32) + + # Convert total_steps to list if necessary + if isinstance(total_steps, SupportsInt): + total_steps = [total_steps] + self.total_steps = np.array(total_steps, dtype=np.int32) + + assert len(self.start_values) == len(self.end_values), "Length of start_values and end_values must be the same." + assert all(x > 0 for x in self.total_steps), "All elements of total_steps must be greater than 0." + + self.current_step = 0 + + def update_on_step(self, *args, **kwargs) -> None: + """ + Update the curriculum based on the current training timestep. + """ + self.current_step += 1 + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + # Linear annealing from start_values to end_values + annealed_values = ( + self.start_values + (self.end_values - self.start_values) * + np.minimum(self.current_step, self.total_steps) / self.total_steps + ) + + return [annealed_values.copy() for _ in range(k)] diff --git a/curricula/domain_randomization.py b/curricula/domain_randomization.py new file mode 100644 index 00000000..e4183904 --- /dev/null +++ b/curricula/domain_randomization.py @@ -0,0 +1,21 @@ +from typing import Any, List + +from syllabus.core import Curriculum + + +class DomainRandomization(Curriculum): + """A simple but strong baseline for curriculum learning that uniformly samples a task from the task space. 
+ """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def _sample_distribution(self) -> List[float]: + """ + Returns a sample distribution over the task space. + """ + # Uniform distribution + return [1.0 / self.num_tasks for _ in range(self.num_tasks)] + + def add_task(self, task: Any) -> None: + self.task_space.add_task(task) diff --git a/curricula/learning_progress.py b/curricula/learning_progress.py new file mode 100644 index 00000000..e9e92f58 --- /dev/null +++ b/curricula/learning_progress.py @@ -0,0 +1,198 @@ +import math +import random +import warnings +from collections import defaultdict +from typing import List + +import numpy as np +from gymnasium.spaces import Discrete, MultiDiscrete +from scipy.stats import norm + +from syllabus.core import Curriculum +from syllabus.task_space import TaskSpace + + +class LearningProgressCurriculum(Curriculum): + """ + Provides an interface for tracking success rates of discrete tasks and sampling tasks + based on their success rate using the method from https://arxiv.org/abs/2106.14876. + TODO: Support task spaces aside from Discrete + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, *args, ema_alpha=0.1, **kwargs): + super().__init__(*args, **kwargs) + self.ema_alpha = ema_alpha + + assert isinstance(self.task_space.gym_space, (Discrete, MultiDiscrete)) + self._p_fast = np.zeros(self.num_tasks) + self._p_slow = np.zeros(self.num_tasks) + + def update_task_progress(self, task: int, progress: float, env_id: int = None): + """ + Update the success rate for the given task using a fast and slow exponential moving average. 
+ """ + if task is None or progress == 0.0: + return + super().update_task_progress(task, progress) + + self._p_fast[task] = (progress * self.ema_alpha) + (self._p_fast[task] * (1.0 - self.ema_alpha)) + self._p_slow[task] = (self._p_fast[task] * self.ema_alpha) + (self._p_slow[task] * (1.0 - self.ema_alpha)) + + def _learning_progress(self, task: int, reweight: bool = True) -> float: + """ + Compute the learning progress metric for the given task. + """ + slow = self._reweight(self._p_slow[task]) if reweight else self._p_slow[task] + fast = self._reweight(self._p_fast[task]) if reweight else self._p_fast[task] + return abs(fast - slow) + + def _reweight(self, p: np.ndarray, p_theta: float = 0.1) -> float: + """ + Reweight the given success rate using the reweighting function from the paper. + """ + numerator = p * (1.0 - p_theta) + denominator = p + p_theta * (1.0 - 2.0 * p) + return numerator / denominator + + def _sigmoid(self, x: np.ndarray): + return 1 / (1 + np.exp(-x)) + + def _sample_distribution(self) -> List[float]: + if self.num_tasks == 0: + return [] + + task_dist = np.ones(self.num_tasks) / self.num_tasks + + task_lps = self._learning_progress(np.asarray(self.tasks)) + posidxs = [i for i, lp in enumerate(task_lps) if lp > 0] + zeroout = len(posidxs) > 0 + + subprobs = task_lps[posidxs] if zeroout else task_lps + std = np.std(subprobs) + subprobs = (subprobs - np.mean(subprobs)) / (std if std else 1) # z-score + subprobs = self._sigmoid(subprobs) # sigmoid + subprobs = subprobs / np.sum(subprobs) # normalize + if zeroout: + # If some tasks have nonzero progress, zero out the rest + task_dist = np.zeros(len(task_lps)) + task_dist[posidxs] = subprobs + else: + # If all tasks have 0 progress, return uniform distribution + task_dist = subprobs + + return task_dist + + def on_step(self, obs, rew, term, trunc, info) -> None: + """ + Update the curriculum with the current step results from the environment. 
+ """ + pass + + +if __name__ == "__main__": + def sample_binomial(p=0.5, n=200): + success = 0.0 + for _ in range(n): + rand = random.random() + if rand < p: + success += 1.0 + return success / n + + def generate_history(center=0, curve=1.0, n=100): + center = center if center else n / 2.0 + + def sig(x, x_0=center, curve=curve): + return 1.0 / (1.0 + math.e**(curve * (x_0 - x))) + history = [] + probs = [] + success_prob = 0.0 + for i in range(n): + probs.append(success_prob) + history.append(sample_binomial(p=success_prob)) + success_prob = sig(i) + return history, probs + + tasks = range(20) + histories = {task: generate_history(center=random.randint(0, 100), curve=random.random()) for task in tasks} + + curriculum = LearningProgressCurriculum(TaskSpace(len(tasks))) + for i in range(len(histories[0][0])): + for task in tasks: + curriculum.update_task_progress(task, histories[task][0][i]) + if i > 10: + distribution = curriculum._sample_distribution() + print("[", end="") + for j, prob in enumerate(distribution): + print(f"{prob:.3f}", end="") + if j < len(distribution) - 1: + print(", ", end="") + print("]") + + tasks = [0] + histories = {task: generate_history(n=200, center=75, curve=0.1) for task in tasks} + curriculum = LearningProgressCurriculum(TaskSpace(len(tasks))) + lp_raw = [] + lp_reweight = [] + p_fast = [] + p_slow = [] + true_probs = [] + estimates = [] + for estimate, true_prob in zip(histories[0][0], histories[0][1]): + curriculum.update_task_progress(tasks[0], estimate) + lp_raw.append(curriculum._learning_progress(tasks[0], reweight=False)) + lp_reweight.append(curriculum._learning_progress(tasks[0])) + p_fast.append(curriculum._p_fast[0]) + p_slow.append(curriculum._p_slow[0]) + true_probs.append(true_prob) + estimates.append(estimate) + + try: + import matplotlib.pyplot as plt + + # TODO: Plot probabilities + def plot_history(true_probs, estimates, p_slow, p_fast, lp_reweight, lp_raw): + x_axis = range(0, len(true_probs)) + plt.plot(x_axis, 
true_probs, color="#222222", label="True Success Probability") + plt.plot(x_axis, estimates, color="#888888", label="Estimated Success Probability") + plt.plot(x_axis, p_slow, color="#ee3333", label="p_slow") + plt.plot(x_axis, p_fast, color="#33ee33", label="p_fast") + plt.plot(x_axis, lp_raw, color="#c4c25b", label="Learning Progress") + plt.plot(x_axis, lp_reweight, color="#1544ee", label="Learning Progress Reweighted") + plt.xlabel('Time step') + plt.ylabel('Learning Progress') + plt.legend() + plt.show() + + plot_history(true_probs, estimates, p_slow, p_fast, lp_reweight, lp_raw) + + # Reweight Plot + x_axis = np.linspace(0, 1, num=100) + y_axis = [] + for x in x_axis: + y_axis.append(curriculum._reweight(x)) + plt.plot(x_axis, y_axis, color="blue", label="p_theta = 0.1") + plt.xlabel('p') + plt.ylabel('reweight') + plt.legend() + plt.show() + + # Z-score plot + tasks = [i for i in range(50)] + curriculum = LearningProgressCurriculum(TaskSpace(len(tasks))) + histories = {task: generate_history(n=200, center=60, curve=0.09) for task in tasks} + for i in range(len(histories[0][0])): + for task in tasks: + curriculum.update_task_progress(task, histories[task][0][i]) + distribution = curriculum._sample_distribution() + x_axis = np.linspace(-3, 3, num=len(distribution)) + sigmoid_axis = curriculum._sigmoid(x_axis) + plt.plot(x_axis, norm.pdf(x_axis, 0, 1), color="blue", label="Normal distribution") + plt.plot(x_axis, sigmoid_axis, color="orange", label="Sampling weight") + plt.xlabel('Z-scored distributed learning progress') + plt.legend() + plt.show() + except ImportError: + warnings.warn("Matplotlib not installed. Plotting will not work.") diff --git a/curricula/noop.py b/curricula/noop.py new file mode 100644 index 00000000..fb5d8ae4 --- /dev/null +++ b/curricula/noop.py @@ -0,0 +1,62 @@ +from typing import Any, List, Union + +from syllabus.core import Curriculum + + +class NoopCurriculum(Curriculum): + """ + Used to test the API without a curriculum.
+ """ + REQUIRES_STEP_UPDATES = True + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, default_task, *curriculum_args, **curriculum_kwargs): + super().__init__(*curriculum_args, **curriculum_kwargs) + self.default_task = self.task_space.encode(default_task) + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + return [self.default_task for _ in range(k)] + + def update_task_progress(self, task, success_prob, env_id: int = None) -> None: + """ + Update the curriculum with a task and its success probability upon + success or failure. + """ + pass + + def update_on_step(self, task, obs, rew, term, trunc, info, env_id: int = None) -> None: + """ + Update the curriculum with the current step results from the environment. + """ + pass + + def update_on_step_batch(self, step_results, env_id: int = None) -> None: + """ + Update the curriculum with a batch of step results from the environment. + """ + pass + + def update_on_episode(self, episode_return, episode_length, episode_task, env_id: int = None) -> None: + """ + Update the curriculum with episode results from the environment. + """ + pass + + def update_on_demand(self, metrics): + """ + Update the curriculum with arbitrary inputs. + """ + pass + + def add_task(self, task: tuple) -> None: + pass + + def update(self, update_data): + """ + Update the curriculum with the specified update type. 
+ """ + pass diff --git a/curricula/plr/__init__.py b/curricula/plr/__init__.py new file mode 100644 index 00000000..75e6bc2d --- /dev/null +++ b/curricula/plr/__init__.py @@ -0,0 +1,3 @@ +from .central_plr_wrapper import CentralizedPrioritizedLevelReplay +from .plr_wrapper import PrioritizedLevelReplay +from .task_sampler import TaskSampler diff --git a/curricula/plr/central_plr_wrapper.py b/curricula/plr/central_plr_wrapper.py new file mode 100644 index 00000000..7f69ea85 --- /dev/null +++ b/curricula/plr/central_plr_wrapper.py @@ -0,0 +1,239 @@ +import warnings +from typing import Any, Dict, List, Tuple, Union + +import gymnasium as gym +import torch +from gymnasium.spaces import Discrete, MultiDiscrete + +from syllabus.core import Curriculum, enumerate_axes +from syllabus.task_space import TaskSpace + +from .task_sampler import TaskSampler + + +class RolloutStorage(object): + def __init__( + self, + num_steps: int, + num_processes: int, + requires_value_buffers: bool, + action_space: gym.Space = None, + ): + self._requires_value_buffers = requires_value_buffers + self.tasks = torch.zeros(num_steps, num_processes, 1, dtype=torch.int) + self.masks = torch.ones(num_steps + 1, num_processes, 1) + + if requires_value_buffers: + self.returns = torch.zeros(num_steps + 1, num_processes, 1) + self.rewards = torch.zeros(num_steps, num_processes, 1) + self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) + else: + if action_space is None: + raise ValueError( + "Action space must be provided to PLR for strategies 'policy_entropy', 'least_confidence', 'min_margin'" + ) + self.action_log_dist = torch.zeros(num_steps, num_processes, action_space.n) + + self.num_steps = num_steps + self.step = 0 + + def to(self, device): + self.masks = self.masks.to(device) + self.tasks = self.tasks.to(device) + if self._requires_value_buffers: + self.rewards = self.rewards.to(device) + self.value_preds = self.value_preds.to(device) + self.returns = self.returns.to(device) + else: + 
self.action_log_dist = self.action_log_dist.to(device) + + def insert(self, masks, action_log_dist=None, value_preds=None, rewards=None, tasks=None): + if self._requires_value_buffers: + assert (value_preds is not None and rewards is not None), "Selected strategy requires value_preds and rewards" + if len(rewards.shape) == 3: + rewards = rewards.squeeze(2) + self.value_preds[self.step].copy_(torch.as_tensor(value_preds)) + self.rewards[self.step].copy_(torch.as_tensor(rewards)[:, None]) + self.masks[self.step + 1].copy_(torch.as_tensor(masks)[:, None]) + else: + self.action_log_dist[self.step].copy_(action_log_dist) + if tasks is not None: + assert isinstance(tasks[0], int), "Provided task must be an integer" + self.tasks[self.step].copy_(torch.as_tensor(tasks)[:, None]) + self.step = (self.step + 1) % self.num_steps + + def after_update(self): + self.masks[0].copy_(self.masks[-1]) + + def compute_returns(self, next_value, gamma, gae_lambda): + assert self._requires_value_buffers, "Selected strategy does not use compute_rewards." + self.value_preds[-1] = next_value + gae = 0 + for step in reversed(range(self.rewards.size(0))): + delta = ( + self.rewards[step] + + gamma * self.value_preds[step + 1] * self.masks[step + 1] + - self.value_preds[step] + ) + gae = delta + gamma * gae_lambda * self.masks[step + 1] * gae + self.returns[step] = gae + self.value_preds[step] + + +class CentralizedPrioritizedLevelReplay(Curriculum): + """ Prioritized Level Replay (PLR) Curriculum. + + Args: + task_space (TaskSpace): The task space to use for the curriculum. + *curriculum_args: Positional arguments to pass to the curriculum. + task_sampler_kwargs_dict (dict): Keyword arguments to pass to the task sampler. See TaskSampler for details. + action_space (gym.Space): The action space to use for the curriculum. Required for some strategies. + device (str): The device to use to store curriculum data, either "cpu" or "cuda". 
+ num_steps (int): The number of steps to store in the rollouts. + num_processes (int): The number of parallel environments. + gamma (float): The discount factor used to compute returns + gae_lambda (float): The GAE lambda value. + suppress_usage_warnings (bool): Whether to suppress warnings about improper usage. + **curriculum_kwargs: Keyword arguments to pass to the curriculum. + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = True + + def __init__( + self, + task_space: TaskSpace, + *curriculum_args, + task_sampler_kwargs_dict: dict = None, + action_space: gym.Space = None, + device: str = "cpu", + num_steps: int = 256, + num_processes: int = 64, + gamma: float = 0.999, + gae_lambda: float = 0.95, + suppress_usage_warnings=False, + **curriculum_kwargs, + ): + # Preprocess curriculum initialization args + if task_sampler_kwargs_dict is None: + task_sampler_kwargs_dict = {} + + self._strategy = task_sampler_kwargs_dict.get("strategy", None) + if not isinstance(task_space.gym_space, Discrete) and not isinstance(task_space.gym_space, MultiDiscrete): + raise ValueError( + f"Task space must be discrete or multi-discrete, got {task_space.gym_space}."
+ ) + if "num_actors" in task_sampler_kwargs_dict and task_sampler_kwargs_dict['num_actors'] != num_processes: + warnings.warn(f"Overwriting 'num_actors' {task_sampler_kwargs_dict['num_actors']} in task sampler kwargs with PLR num_processes {num_processes}.") + task_sampler_kwargs_dict["num_actors"] = num_processes + super().__init__(task_space, *curriculum_args, **curriculum_kwargs) + + self._num_steps = num_steps # Number of steps stored in rollouts and used to update task sampler + self._num_processes = num_processes # Number of parallel environments + self._gamma = gamma + self._gae_lambda = gae_lambda + self._supress_usage_warnings = suppress_usage_warnings + self._task2index = {task: i for i, task in enumerate(self.tasks)} + self._task_sampler = TaskSampler(self.tasks, action_space=action_space, **task_sampler_kwargs_dict) + self._rollouts = RolloutStorage( + self._num_steps, + self._num_processes, + self._task_sampler.requires_value_buffers, + action_space=action_space, + ) + self._rollouts.to(device) + # TODO: Fix this feature + self.num_updates = 0 # Used to ensure proper usage + self.num_samples = 0 # Used to ensure proper usage + + def _validate_metrics(self, metrics: Dict): + try: + masks = torch.Tensor(1 - metrics["dones"]) + tasks = metrics["tasks"] + tasks = [self._task2index[t] for t in tasks] + except KeyError as e: + raise KeyError( + "Missing or malformed PLR update. Must include 'masks', and 'tasks', and all tasks must be in the task space" + ) from e + + # Parse optional update values (required for some strategies) + value = next_value = rew = action_log_dist = None + if self._task_sampler.requires_value_buffers: + if "value" not in metrics or "rew" not in metrics: + raise KeyError( + f"'value' and 'rew' must be provided in every update for the strategy {self._strategy}." 
+ ) + value = metrics["value"] + rew = metrics["rew"] + else: + try: + action_log_dist = metrics["action_log_dist"] + except KeyError as e: + raise KeyError( + f"'action_log_dist' must be provided in every update for the strategy {self._strategy}." + ) from e + + if self._task_sampler.requires_value_buffers: + try: + next_value = metrics["next_value"] + except KeyError as e: + raise KeyError( + f"'next_value' must be provided in the update every {self.num_steps} steps for the strategy {self._strategy}." + ) from e + + return masks, tasks, value, rew, action_log_dist, next_value + + def update_on_demand(self, metrics: Dict): + """ + Update the curriculum with arbitrary inputs. + """ + self.num_updates += 1 + masks, tasks, value, rew, action_log_dist, next_value = self._validate_metrics(metrics) + + # Update rollouts + self._rollouts.insert( + masks, + action_log_dist=action_log_dist, + value_preds=value, + rewards=rew, + tasks=tasks, + ) + + # Update task sampler + if self._rollouts.step == 0: + if self._task_sampler.requires_value_buffers: + self._rollouts.compute_returns(next_value, self._gamma, self._gae_lambda) + self._task_sampler.update_with_rollouts(self._rollouts) + self._rollouts.after_update() + self._task_sampler.after_update() + + def _sample_distribution(self) -> List[float]: + """ + Returns a sample distribution over the task space. 
class RolloutStorage(object):
    """Rollout buffer in which every environment fills its own column.

    Tensors are laid out as ``(step, env, feature)``. ``env_steps[i]`` counts
    how many steps environment ``i`` has written, and an environment is added
    to ``ready_buffers`` once it holds at least ``num_steps`` steps. The
    buffer is four times longer than ``num_steps`` so that an environment can
    keep writing while its previous window is still waiting to be consumed.

    Args:
        num_steps: Number of steps consumed per task-sampler update.
        num_processes: Number of parallel environments.
        requires_value_buffers: Whether rewards, value predictions and returns
            are stored (otherwise action log distributions are stored).
        observation_space: Accepted for interface compatibility; not read here.
        action_space: Required when ``requires_value_buffers`` is False; its
            ``n`` attribute sizes the action distribution buffer.
        get_value: Callable mapping a list of observations to value
            predictions. Needed before ``compute_returns`` is called.
    """

    def __init__(
        self,
        num_steps: int,
        num_processes: int,
        requires_value_buffers: bool,
        observation_space: "gym.Space",
        action_space: "gym.Space" = None,
        get_value=None,
    ):
        self.num_steps = num_steps
        self.buffer_steps = num_steps * 4  # Hack to prevent overflow from lagging updates.
        self.num_processes = num_processes
        self._requires_value_buffers = requires_value_buffers
        self._get_value = get_value
        self.tasks = torch.zeros(self.buffer_steps, num_processes, 1, dtype=torch.int)
        self.masks = torch.ones(self.buffer_steps + 1, num_processes, 1)
        # Build every row separately: `[row] * n` would alias one row across
        # all steps, so writing a single step would overwrite all of them.
        self.obs = [[[0] for _ in range(num_processes)] for _ in range(self.buffer_steps)]
        self.env_steps = [0] * num_processes
        self.ready_buffers = set()

        if requires_value_buffers:
            self.returns = torch.zeros(self.buffer_steps + 1, num_processes, 1)
            self.rewards = torch.zeros(self.buffer_steps, num_processes, 1)
            self.value_preds = torch.zeros(self.buffer_steps + 1, num_processes, 1)
        else:
            if action_space is None:
                raise ValueError(
                    "Action space must be provided to PLR for strategies 'policy_entropy', 'least_confidence', 'min_margin'"
                )
            self.action_log_dist = torch.zeros(self.buffer_steps, num_processes, action_space.n)

    def to(self, device):
        """Move every tensor buffer to ``device``."""
        self.masks = self.masks.to(device)
        self.tasks = self.tasks.to(device)
        if self._requires_value_buffers:
            self.rewards = self.rewards.to(device)
            self.value_preds = self.value_preds.to(device)
            self.returns = self.returns.to(device)
        else:
            self.action_log_dist = self.action_log_dist.to(device)

    def insert_at_index(self, env_index, mask=None, action_log_dist=None, obs=None, reward=None, task=None, steps=1):
        """Append ``steps`` steps of data to the column of ``env_index``.

        Every provided field must be indexable along its first axis with one
        entry per step. Fields left as ``None`` are skipped, and fields whose
        buffer was not allocated for the current strategy are ignored.
        """
        step = self.env_steps[env_index]
        end_step = step + steps

        if mask is not None:
            # Masks are offset by one: masks[0] describes the state before the first step.
            self.masks[step + 1:end_step + 1, env_index].copy_(torch.as_tensor(mask[:, None]))

        if obs is not None:
            for s in range(step, end_step):
                self.obs[s][env_index] = obs[s - step]

        # `rewards` only exists for value-based strategies and
        # `action_log_dist` only for the others, so skip whichever buffer was
        # not allocated instead of failing with AttributeError.
        if reward is not None and self._requires_value_buffers:
            self.rewards[step:end_step, env_index].copy_(torch.as_tensor(reward[:, None]))

        if action_log_dist is not None and not self._requires_value_buffers:
            # NOTE(review): the extra axis assumes a per-step layout that is
            # not visible here; confirm against the get_action_log_dist caller.
            self.action_log_dist[step:end_step, env_index].copy_(torch.as_tensor(action_log_dist[:, None]))

        if task is not None:
            try:
                int(task[0])
            except TypeError:
                assert isinstance(task, int), f"Provided task must be an integer, got {task[0]} with type {type(task[0])} instead."
            self.tasks[step:end_step, env_index].copy_(torch.as_tensor(np.array(task)[:, None]))

        self.env_steps[env_index] += steps
        if env_index not in self.ready_buffers and self.env_steps[env_index] >= self.num_steps:
            self.ready_buffers.add(env_index)

    def _get_values(self, env_index):
        """Fill ``value_preds`` for the first ``num_steps`` steps of one environment.

        Observations are sent to the value function in chunks of at most
        ``num_processes`` entries.

        Raises:
            UsageError: If no value function has been provided.
        """
        if self._get_value is None:
            raise UsageError("Selected strategy requires value predictions. Please provide get_value function.")
        for start in range(0, self.num_steps, self.num_processes):
            # Clamp so that placeholder observations past the consumed window
            # are never sent to the value function.
            stop = min(start + self.num_processes, self.num_steps)
            # obs is indexed [step][env]: gather this environment's
            # observation at each step of the chunk.
            obs = [self.obs[s][env_index] for s in range(start, stop)]
            values = self._get_value(obs)

            # Reshape values if necessary
            if len(values.shape) == 3:
                warnings.warn(f"Value function returned a 3D tensor of shape {values.shape}. Attempting to squeeze last dimension.")
                values = torch.squeeze(values, -1)
            if len(values.shape) == 1:
                warnings.warn(f"Value function returned a 1D tensor of shape {values.shape}. Attempting to unsqueeze last dimension.")
                values = torch.unsqueeze(values, -1)

            self.value_preds[start:stop, env_index].copy_(values)

    def after_update(self, env_index):
        """Drop the consumed ``num_steps`` steps of ``env_index`` and shift the rest forward.

        Only the column of ``env_index`` is shifted; other environments keep
        their own write positions and must not be moved.
        """
        shift = self.num_steps
        self.tasks[:, env_index] = self.tasks[:, env_index].roll(-shift, 0)
        self.masks[:, env_index] = self.masks[:, env_index].roll(-shift, 0)

        kept = self.buffer_steps - shift
        for s in range(kept):
            self.obs[s][env_index] = self.obs[s + shift][env_index]
        for s in range(kept, self.buffer_steps):
            self.obs[s][env_index] = [0]

        if self._requires_value_buffers:
            self.returns[:, env_index] = self.returns[:, env_index].roll(-shift, 0)
            self.rewards[:, env_index] = self.rewards[:, env_index].roll(-shift, 0)
            self.value_preds[:, env_index] = self.value_preds[:, env_index].roll(-shift, 0)
        else:
            self.action_log_dist[:, env_index] = self.action_log_dist[:, env_index].roll(-shift, 0)

        self.env_steps[env_index] -= self.num_steps
        self.ready_buffers.remove(env_index)

    def compute_returns(self, gamma, gae_lambda, env_index):
        """Compute GAE returns for the first ``num_steps`` steps of one environment.

        Value predictions are refreshed through ``_get_values`` first. The
        bootstrap term reads ``value_preds[num_steps]``, which this method does
        not recompute.
        """
        assert self._requires_value_buffers, "Selected strategy does not use compute_rewards."
        self._get_values(env_index)
        gae = 0
        # Walk the consumed window backwards. The previous bounds
        # (rewards.size(0) .. num_steps) described an empty range, so no
        # return was ever written.
        for step in reversed(range(self.num_steps)):
            delta = (
                self.rewards[step, env_index]
                + gamma * self.value_preds[step + 1, env_index] * self.masks[step + 1, env_index]
                - self.value_preds[step, env_index]
            )
            gae = delta + gamma * gae_lambda * self.masks[step + 1, env_index] * gae
            self.returns[step, env_index] = gae + self.value_preds[step, env_index]
class PrioritizedLevelReplay(Curriculum):
    """ Prioritized Level Replay (PLR) Curriculum.

    Args:
        task_space (TaskSpace): The task space to use for the curriculum.
        observation_space (gym.Space): The observation space, forwarded to the rollout storage.
        *curriculum_args: Positional arguments to pass to the curriculum.
        task_sampler_kwargs_dict (dict): Keyword arguments to pass to the task sampler. See TaskSampler for details.
        action_space (gym.Space): The action space to use for the curriculum. Required for some strategies.
        device (str): The device to use to store curriculum data, either "cpu" or "cuda".
        num_steps (int): The number of steps to store in the rollouts.
        num_processes (int): The number of parallel environments.
        gamma (float): The discount factor used to compute returns
        gae_lambda (float): The GAE lambda value.
        suppress_usage_warnings (bool): Whether to suppress warnings about improper usage.
        get_value (Callable): Maps observations to value predictions. Can also be set later with set_value_fn.
        get_action_log_dist (Callable): Maps observations to action log distributions.
        **curriculum_kwargs: Keyword arguments to pass to the curriculum.
    """
    REQUIRES_STEP_UPDATES = True
    REQUIRES_EPISODE_UPDATES = False
    REQUIRES_CENTRAL_UPDATES = False

    def __init__(
        self,
        task_space: TaskSpace,
        observation_space: gym.Space,
        *curriculum_args,
        task_sampler_kwargs_dict: dict = None,
        action_space: gym.Space = None,
        device: str = "cpu",
        num_steps: int = 256,
        num_processes: int = 64,
        gamma: float = 0.999,
        gae_lambda: float = 0.95,
        suppress_usage_warnings=False,
        get_value=null,
        get_action_log_dist=null,
        **curriculum_kwargs,
    ):
        # Preprocess curriculum initialization args
        if task_sampler_kwargs_dict is None:
            task_sampler_kwargs_dict = {}

        self._strategy = task_sampler_kwargs_dict.get("strategy", None)
        if not isinstance(task_space.gym_space, (Discrete, MultiDiscrete)):
            raise ValueError(
                f"Task space must be discrete or multi-discrete, got {task_space.gym_space}."
            )
        if "num_actors" in task_sampler_kwargs_dict and task_sampler_kwargs_dict['num_actors'] != num_processes:
            warnings.warn(f"Overwriting 'num_actors' {task_sampler_kwargs_dict['num_actors']} in task sampler kwargs with PLR num_processes {num_processes}.")
        task_sampler_kwargs_dict["num_actors"] = num_processes
        super().__init__(task_space, *curriculum_args, **curriculum_kwargs)

        self._num_steps = num_steps  # Number of steps stored in rollouts and used to update task sampler
        self._num_processes = num_processes  # Number of parallel environments
        self._gamma = gamma
        self._gae_lambda = gae_lambda
        self._supress_usage_warnings = suppress_usage_warnings
        # Fall back to the no-op so update_on_step can always call it.
        self._get_action_log_dist = get_action_log_dist if get_action_log_dist is not None else null
        self._task2index = {task: i for i, task in enumerate(self.tasks)}

        self._task_sampler = TaskSampler(self.tasks, action_space=action_space, **task_sampler_kwargs_dict)
        self._rollouts = RolloutStorage(
            self._num_steps,
            self._num_processes,
            self._task_sampler.requires_value_buffers,
            observation_space,
            action_space=action_space,
            # Hand the storage None rather than the `null` placeholder so its
            # "no value function provided" check can fire with a clear error
            # instead of failing later on a None value prediction.
            get_value=None if get_value is null else get_value,
        )
        self._rollouts.to(device)

    def set_value_fn(self, value_fn):
        """Set the function the rollout storage uses to predict values."""
        self._rollouts._get_value = value_fn

    def _sample_distribution(self) -> List[float]:
        """
        Returns a sample distribution over the task space.
        """
        return self._task_sampler.sample_weights()

    def sample(self, k: int = 1) -> Union[List, Any]:
        """Return tasks from startup sampling while it is active, else k tasks from the task sampler."""
        if self._should_use_startup_sampling():
            return self._startup_sample()
        return [self._task_sampler.sample() for _ in range(k)]

    def _resolve_env_id(self, env_id: int) -> int:
        """Check that env_id was given and wrap it into the range of known processes."""
        assert env_id is not None, "env_id must be provided for PLR updates."
        if env_id >= self._num_processes:
            warnings.warn(f"Env index {env_id} is greater than the number of processes {self._num_processes}. Using index {env_id % self._num_processes} instead.")
            env_id = env_id % self._num_processes
        return env_id

    def update_on_step(self, task, obs, rew, term, trunc, info, env_id: int = None) -> None:
        """
        Update the curriculum with the current step results from the environment.
        """
        env_id = self._resolve_env_id(env_id)

        # Update rollouts
        self._rollouts.insert_at_index(
            env_id,
            mask=np.array([not (term or trunc)]),
            action_log_dist=self._get_action_log_dist(obs),
            reward=np.array([rew]),
            obs=np.array([obs]),
            # Record the task like the batched path does; without it the task
            # buffer stays at its zero initialisation for single-step updates.
            task=None if task is None else [task],
        )

        # Update task sampler
        if env_id in self._rollouts.ready_buffers:
            self._update_sampler(env_id)

    def update_on_step_batch(
        self, step_results: List[Tuple[int, Any, int, bool, bool, Dict]], env_id: int = None
    ) -> None:
        """
        Update the curriculum with a batch of step results from the environment.
        """
        env_id = self._resolve_env_id(env_id)

        tasks, obs, rews, terms, truncs, infos = step_results
        self._rollouts.insert_at_index(
            env_id,
            mask=np.logical_not(np.logical_or(terms, truncs)),
            action_log_dist=self._get_action_log_dist(obs),
            reward=rews,
            obs=obs,
            steps=len(rews),
            task=tasks,
        )

        # Update task sampler
        if env_id in self._rollouts.ready_buffers:
            self._update_sampler(env_id)

    def _update_sampler(self, env_id):
        """Score the finished window of env_id, then release it from the rollout storage."""
        if self._task_sampler.requires_value_buffers:
            self._rollouts.compute_returns(self._gamma, self._gae_lambda, env_id)
        self._task_sampler.update_with_rollouts(self._rollouts, env_id)
        self._rollouts.after_update(env_id)
        self._task_sampler.after_update()

    def _enumerate_tasks(self, space):
        """List every task of a Discrete or MultiDiscrete space."""
        assert isinstance(space, (Discrete, MultiDiscrete)), f"Unsupported task space {space}: Expected Discrete or MultiDiscrete"
        if isinstance(space, Discrete):
            return list(range(space.n))
        return list(enumerate_axes(space.nvec))

    def log_metrics(self, writer, step=None):
        """
        Log the task distribution to the provided tensorboard writer.
        """
        # NOTE: base-class logging and per-task scalars are currently disabled.
        # super().log_metrics(writer, step)
        metrics = self._task_sampler.metrics()
        writer.add_scalar("curriculum/proportion_seen", metrics["proportion_seen"], step)
        writer.add_scalar("curriculum/score", metrics["score"], step)
        # for task in list(self.task_space.tasks)[:10]:
        #     writer.add_scalar(f"curriculum/task_{task - 1}_score", metrics["task_scores"][task - 1], step)
        #     writer.add_scalar(f"curriculum/task_{task - 1}_staleness", metrics["task_staleness"][task - 1], step)
class TaskSampler:
    """ Task sampler for Prioritized Level Replay (PLR)

    Args:
        tasks (list): List of tasks to sample from
        action_space (gym.spaces.Space): Action space of the environment
        num_actors (int): Number of actors/processes
        strategy (str): Strategy for sampling tasks. One of "value_l1", "gae", "policy_entropy", "least_confidence", "min_margin", "one_step_td_error".
        replay_schedule (str): Schedule for sampling replay levels. One of "fixed" or "proportionate".
        score_transform (str): Transform to apply to task scores. One of "constant", "max", "eps_greedy", "rank", "power", "softmax".
        temperature (float): Temperature for score transform. Increasing temperature makes the sampling distribution more uniform.
        eps (float): Epsilon for eps-greedy score transform.
        rho (float): Proportion of seen tasks before replay sampling is allowed.
        nu (float): Probability of sampling a replay level if using a fixed replay_schedule.
        alpha (float): Linear interpolation weight for score updates. 0.0 means only use old scores, 1.0 means only use new scores.
        staleness_coef (float): Linear interpolation weight for task staleness vs. task score. 0.0 means only use task score, 1.0 means only use staleness.
        staleness_transform (str): Transform to apply to task staleness. One of "constant", "max", "eps_greedy", "rank", "power", "softmax".
        staleness_temperature (float): Temperature for staleness transform. Increasing temperature makes the sampling distribution more uniform.

    Raises:
        ValueError: If the strategy scores tasks from action distributions and no action space is given.
    """
    def __init__(
        self,
        tasks: list,
        action_space: "gym.spaces.Space" = None,
        num_actors: int = 1,
        strategy: str = "value_l1",
        replay_schedule: str = "proportionate",
        score_transform: str = "rank",
        temperature: float = 0.1,
        eps: float = 0.05,
        rho: float = 1.0,
        nu: float = 0.5,
        alpha: float = 1.0,
        staleness_coef: float = 0.1,
        staleness_transform: str = "power",
        staleness_temperature: float = 1.0,
    ):
        self.action_space = action_space
        self.tasks = tasks
        self.num_tasks = len(self.tasks)

        self.strategy = strategy
        self.replay_schedule = replay_schedule
        self.score_transform = score_transform
        self.temperature = temperature
        self.eps = eps
        self.rho = rho
        self.nu = nu
        self.alpha = float(alpha)
        self.staleness_coef = staleness_coef
        self.staleness_transform = staleness_transform
        self.staleness_temperature = staleness_temperature

        self.unseen_task_weights = np.array([1.0] * self.num_tasks)
        self.task_scores = np.array([0.0] * self.num_tasks, dtype=float)
        self.partial_task_scores = np.zeros((num_actors, self.num_tasks), dtype=float)
        self.partial_task_steps = np.zeros((num_actors, self.num_tasks), dtype=np.int64)
        self.task_staleness = np.array([0.0] * self.num_tasks, dtype=float)

        self.next_task_index = 0  # Only used for sequential strategy

        # Logging metrics
        self._last_score = 0.0

        if not self.requires_value_buffers and self.action_space is None:
            raise ValueError(
                'Must provide action space to PLR if using "policy_entropy", "least_confidence", or "min_margin" strategies'
            )

    def update_with_rollouts(self, rollouts, actor_id=None):
        """Update task scores from a rollout storage using the configured strategy.

        Does nothing for the "random" strategy.

        Raises:
            ValueError: If the strategy has no score function.
        """
        if self.strategy == "random":
            return

        score_functions = {
            "policy_entropy": self._average_entropy,
            "least_confidence": self._average_least_confidence,
            "min_margin": self._average_min_margin,
            "gae": self._average_gae,
            "value_l1": self._average_value_l1,
            "one_step_td_error": self._one_step_td_error,
        }
        if self.strategy not in score_functions:
            raise ValueError(f"Unsupported strategy, {self.strategy}")

        self._update_with_rollouts(rollouts, score_functions[self.strategy], actor_index=actor_id)

    def update_task_score(self, actor_index, task_idx, score, num_steps):
        """Finalize the score of a task, mark it as seen, and blend it into ``task_scores`` with ``alpha``."""
        score = self._partial_update_task_score(actor_index, task_idx, score, num_steps, done=True)

        self.unseen_task_weights[task_idx] = 0.0  # No longer unseen

        old_score = self.task_scores[task_idx]
        self.task_scores[task_idx] = (1.0 - self.alpha) * old_score + self.alpha * score

    def _partial_update_task_score(self, actor_index, task_idx, score, num_steps, done=False):
        """Merge ``score`` into the step-weighted running mean for (actor, task) and return the merged value.

        When ``done`` is True the running state is cleared instead of stored.
        """
        partial_score = self.partial_task_scores[actor_index][task_idx]
        partial_num_steps = self.partial_task_steps[actor_index][task_idx]

        running_num_steps = partial_num_steps + num_steps
        merged_score = partial_score + (score - partial_score) * num_steps / float(running_num_steps)
        if done:
            self.partial_task_scores[actor_index][task_idx] = 0.0  # zero partial score, partial num_steps
            self.partial_task_steps[actor_index][task_idx] = 0
        else:
            self.partial_task_scores[actor_index][task_idx] = merged_score
            self.partial_task_steps[actor_index][task_idx] = running_num_steps

        return merged_score

    def _average_entropy(self, **kwargs):
        """Mean policy entropy of the episode, normalized by the maximum entropy."""
        episode_logits = kwargs["episode_logits"]
        num_actions = self.action_space.n
        max_entropy = -(1.0 / num_actions) * np.log(1.0 / num_actions) * num_actions

        return (-torch.exp(episode_logits) * episode_logits).sum(-1).mean().item() / max_entropy

    def _average_least_confidence(self, **kwargs):
        """Mean of one minus the probability of the most likely action."""
        episode_logits = kwargs["episode_logits"]
        return (1 - torch.exp(episode_logits.max(-1, keepdim=True)[0])).mean().item()

    def _average_min_margin(self, **kwargs):
        """One minus the mean probability gap between the two most likely actions."""
        episode_logits = kwargs["episode_logits"]
        top2_confidence = torch.exp(episode_logits.topk(2, dim=-1)[0])
        return 1 - (top2_confidence[:, 0] - top2_confidence[:, 1]).mean().item()

    def _average_gae(self, **kwargs):
        """Mean of ``returns - value_preds``."""
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]

        advantages = returns - value_preds

        return advantages.mean().item()

    def _average_value_l1(self, **kwargs):
        """Mean absolute value of ``returns - value_preds``."""
        returns = kwargs["returns"]
        value_preds = kwargs["value_preds"]

        advantages = returns - value_preds

        return advantages.abs().mean().item()

    def _one_step_td_error(self, **kwargs):
        """Mean absolute one-step difference built from rewards and consecutive value predictions."""
        rewards = kwargs["rewards"]
        value_preds = kwargs["value_preds"]

        max_t = len(rewards)
        td_errors = (rewards[:-1] + value_preds[: max_t - 1] - value_preds[1:max_t]).abs()
        assert not torch.isnan(
            td_errors.abs().mean()
        ), f"Got invalid values for 'rewards' or 'value_preds'. Check that reward length: {len(rewards)}"
        return td_errors.abs().mean().item()

    @property
    def requires_value_buffers(self):
        """Whether the strategy scores tasks from returns, rewards and value predictions."""
        return self.strategy in ["gae", "value_l1", "one_step_td_error"]

    def _score_kwargs(self, rollouts, actor_index, start, stop=None):
        """Collect score-function inputs for steps ``start:stop`` of one actor."""
        window = slice(start, stop)
        if self.requires_value_buffers:
            return {
                "returns": rollouts.returns[window, actor_index],
                "rewards": rollouts.rewards[window, actor_index],
                "value_preds": rollouts.value_preds[window, actor_index],
            }
        episode_logits = rollouts.action_log_dist[window, actor_index]
        return {"episode_logits": torch.log_softmax(episode_logits, -1)}

    def _update_with_rollouts(self, rollouts, score_function, actor_index=None):
        """Split each actor's rollout at done steps and score every segment.

        Segments that end at a done step update the task score directly. The
        trailing unfinished segment only feeds the partial running score.
        """
        tasks = rollouts.tasks
        done = ~(rollouts.masks > 0)
        total_steps, num_actors = rollouts.tasks.shape[:2]

        actors = [actor_index] if actor_index is not None else range(num_actors)
        for actor_index in actors:
            done_steps = done[:, actor_index].nonzero()[:total_steps, 0]
            start_t = 0

            for t in done_steps:
                if not start_t < total_steps:
                    break

                if (t == 0):  # if t is 0, then this done step caused a full update of previous last cycle
                    continue

                # If there is only 1 step, we can't calculate the one-step td error
                if self.strategy == "one_step_td_error" and t - start_t <= 1:
                    continue

                task_idx_t = tasks[start_t, actor_index].item()
                score = score_function(**self._score_kwargs(rollouts, actor_index, start_t, t))
                num_steps = len(rollouts.tasks[start_t:t, actor_index])
                # TODO: Check that task_idx_t is correct
                self.update_task_score(actor_index, task_idx_t, score, num_steps)

                start_t = t.item()

            if start_t < total_steps:
                # If there is only 1 step, we can't calculate the one-step td error
                if self.strategy == "one_step_td_error" and start_t == total_steps - 1:
                    continue
                # TODO: Check this too
                task_idx_t = tasks[start_t, actor_index].item()

                score = score_function(**self._score_kwargs(rollouts, actor_index, start_t))
                self._last_score = score
                num_steps = len(rollouts.tasks[start_t:, actor_index])
                self._partial_update_task_score(actor_index, task_idx_t, score, num_steps)

    def after_update(self):
        """Flush all partial scores into the task scores and reset the partial state."""
        # Reset partial updates, since weights have changed, and thus logits are now stale
        for actor_index in range(self.partial_task_scores.shape[0]):
            for task_idx in range(self.partial_task_scores.shape[1]):
                if self.partial_task_scores[actor_index][task_idx] != 0:
                    self.update_task_score(actor_index, task_idx, 0, 0)
        self.partial_task_scores.fill(0)
        self.partial_task_steps.fill(0)

    def _update_staleness(self, selected_idx):
        """Age every task by one and reset the staleness of the selected task."""
        if self.staleness_coef > 0:
            self.task_staleness = self.task_staleness + 1
            self.task_staleness[selected_idx] = 0

    def _sample_replay_level(self):
        """Sample a task from the replay distribution, falling back to uniform if all weights are zero."""
        sample_weights = self.sample_weights()
        if np.isclose(np.sum(sample_weights), 0):
            sample_weights = np.ones_like(sample_weights, dtype=float) / len(sample_weights)

        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        task = self.tasks[task_idx]

        self._update_staleness(task_idx)

        return task

    def _sample_unseen_level(self):
        """Sample uniformly among tasks that have not been scored yet."""
        sample_weights = self.unseen_task_weights / self.unseen_task_weights.sum()
        task_idx = np.random.choice(range(self.num_tasks), 1, p=sample_weights)[0]
        task = self.tasks[task_idx]

        self._update_staleness(task_idx)

        return task

    def sample(self, strategy=None):
        """Sample the next task.

        Args:
            strategy: Optional override of the configured strategy. "random"
                and "sequential" bypass the replay schedule.

        Raises:
            NotImplementedError: If the replay schedule is not "fixed" or "proportionate".
        """
        if not strategy:
            strategy = self.strategy

        if strategy == "random":
            task_idx = np.random.choice(range((self.num_tasks)))
            return self.tasks[task_idx]

        if strategy == "sequential":
            task_idx = self.next_task_index
            self.next_task_index = (self.next_task_index + 1) % self.num_tasks
            return self.tasks[task_idx]

        num_unseen = (self.unseen_task_weights > 0).sum()
        proportion_seen = (self.num_tasks - num_unseen) / self.num_tasks

        if self.replay_schedule == "fixed":
            if proportion_seen >= self.rho:
                # Sample replay level with fixed prob = 1 - nu OR if all levels seen
                if np.random.rand() > self.nu or not proportion_seen < 1.0:
                    return self._sample_replay_level()

            # Otherwise, sample a new level
            return self._sample_unseen_level()

        elif self.replay_schedule == "proportionate":
            if proportion_seen >= self.rho and np.random.rand() < proportion_seen:
                return self._sample_replay_level()
            else:
                return self._sample_unseen_level()
        else:
            raise NotImplementedError(f"Unsupported replay schedule: {self.replay_schedule}. Must be 'fixed' or 'proportionate'.")

    def sample_weights(self):
        """Return the replay weights: transformed scores mixed with transformed staleness, with unseen tasks zeroed."""
        weights = self._score_transform(self.score_transform, self.temperature, self.task_scores)
        weights = weights * (1 - self.unseen_task_weights)  # zero out unseen levels
        z = np.sum(weights)
        if z > 0:
            weights /= z

        staleness_weights = 0
        if self.staleness_coef > 0:
            staleness_weights = self._score_transform(
                self.staleness_transform,
                self.staleness_temperature,
                self.task_staleness,
            )
            staleness_weights = staleness_weights * (1 - self.unseen_task_weights)
            z = np.sum(staleness_weights)
            if z > 0:
                staleness_weights /= z
            weights = (1 - self.staleness_coef) * weights + self.staleness_coef * staleness_weights
        return weights

    def _score_transform(self, transform, temperature, scores):
        """Turn raw scores into unnormalized sampling weights.

        The ``scores`` array is never modified.

        Raises:
            ValueError: If ``transform`` is not a supported transform name.
        """
        if transform == "constant":
            weights = np.ones_like(scores)
        elif transform == "max":
            weights = np.zeros_like(scores)
            # Work on a real copy: slicing a NumPy array returns a view, so
            # masking it in place would write -inf into the caller's scores.
            scores = np.array(scores, dtype=float)
            scores[self.unseen_task_weights > 0] = -float("inf")  # only argmax over seen levels
            argmax = np.random.choice(np.flatnonzero(np.isclose(scores, scores.max())))
            weights[argmax] = 1.0
        elif transform == "eps_greedy":
            weights = np.zeros_like(scores)
            weights[scores.argmax()] = 1.0 - self.eps
            weights += self.eps / self.num_tasks
        elif transform == "rank":
            temp = np.flip(scores.argsort())
            ranks = np.empty_like(temp)
            ranks[temp] = np.arange(len(temp)) + 1
            weights = 1 / ranks ** (1.0 / temperature)
        elif transform == "power":
            eps = 0 if self.staleness_coef > 0 else 1e-3
            weights = (np.array(scores) + eps) ** (1.0 / temperature)
        elif transform == "softmax":
            weights = np.exp(np.array(scores) / temperature)
        else:
            raise ValueError(f"Unsupported score transform: {transform}")

        return weights

    def metrics(self):
        """Return the current scores, unseen weights, staleness, proportion of seen tasks and last score."""
        return {
            "task_scores": self.task_scores,
            "unseen_task_weights": self.unseen_task_weights,
            "task_staleness": self.task_staleness,
            "proportion_seen": (self.num_tasks - (self.unseen_task_weights > 0).sum()) / self.num_tasks,
            "score": self._last_score,
        }
a/curricula/sequential.py b/curricula/sequential.py new file mode 100644 index 00000000..ec3b8b09 --- /dev/null +++ b/curricula/sequential.py @@ -0,0 +1,207 @@ +import re +import warnings +from typing import Any, Callable, List, Union + +from syllabus.core import Curriculum +from syllabus.curricula import NoopCurriculum, DomainRandomization +from syllabus.task_space import TaskSpace + + +class SequentialCurriculum(Curriculum): + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = True + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, curriculum_list: List[Curriculum], stopping_conditions: List[Any], *curriculum_args, **curriculum_kwargs): + super().__init__(*curriculum_args, **curriculum_kwargs) + assert len(curriculum_list) > 0, "Must provide at least one curriculum" + assert len(stopping_conditions) == len(curriculum_list) - 1, f"Stopping conditions must be one less than the number of curricula. Final curriculum is used for the remainder of training. Expected {len(curriculum_list) - 1}, got {len(stopping_conditions)}." + if len(curriculum_list) == 1: + warnings.warn("Your sequential curriculum only containes one element. Consider using that element directly instead.") + + self.curriculum_list = self._parse_curriculum_list(curriculum_list) + self.stopping_conditions = self._parse_stopping_conditions(stopping_conditions) + self._curriculum_index = 0 + + # Stopping metrics + self.n_steps = 0 + self.total_steps = 0 + self.n_episodes = 0 + self.total_episodes = 0 + self.n_tasks = 0 + self.total_tasks = 0 + self.episode_returns = [] + + def _parse_curriculum_list(self, curriculum_list: List[Curriculum]) -> List[Curriculum]: + """ Parse the curriculum list to ensure that all items are curricula. + Adds Curriculum objects directly. Wraps task space items in NoopCurriculum objects. 
+ """ + parsed_list = [] + for item in curriculum_list: + if isinstance(item, Curriculum): + parsed_list.append(item) + elif isinstance(item, TaskSpace): + parsed_list.append(DomainRandomization(item)) + elif isinstance(item, list): + task_space = TaskSpace(len(item), item) + parsed_list.append(DomainRandomization(task_space)) + elif self.task_space.contains(item): + parsed_list.append(NoopCurriculum(item, self.task_space)) + else: + raise ValueError(f"Invalid curriculum item: {item}") + + return parsed_list + + def _parse_stopping_conditions(self, stopping_conditions: List[Any]) -> List[Any]: + """ Parse the stopping conditions to ensure that all items are integers. """ + parsed_list = [] + for item in stopping_conditions: + if isinstance(item, Callable): + parsed_list.append(item) + elif isinstance(item, str): + parsed_list.append(self._parse_condition_string(item)) + else: + raise ValueError(f"Invalid stopping condition: {item}") + + return parsed_list + + def _parse_condition_string(self, condition: str) -> Callable: + """ Parse a string condition to a callable function. 
""" + + # Parse composite conditions + if '|' in condition: + conditions = re.split(re.escape('|'), condition) + return lambda: any(self._parse_condition_string(cond)() for cond in conditions) + elif '&' in condition: + conditions = re.split(re.escape('&'), condition) + return lambda: all(self._parse_condition_string(cond)() for cond in conditions) + + clauses = re.split('(<=|>=|=|<|>)', condition) + + try: + metric, comparator, value = clauses + + if metric == "steps": + metric_fn = self._get_steps + elif metric == "total_steps": + metric_fn = self._get_total_steps + elif metric == "episodes": + metric_fn = self._get_episodes + elif metric == "total_episodes": + metric_fn = self._get_total_episodes + elif metric == "tasks": + metric_fn = self._get_tasks + elif metric == "total_tasks": + metric_fn = self._get_total_tasks + elif metric == "episode_return": + metric_fn = self._get_episode_return + else: + raise ValueError(f"Invalid metric name: {metric}") + + if comparator == '<': + return lambda: metric_fn() < float(value) + elif comparator == '>': + return lambda: metric_fn() > float(value) + elif comparator == '<=': + return lambda: metric_fn() <= float(value) + elif comparator == '>=': + return lambda: metric_fn() >= float(value) + elif comparator == '=': + return lambda: metric_fn() == float(value) + else: + raise ValueError(f"Invalid comparator: {comparator}") + except ValueError as e: + raise ValueError(f"Invalid condition string: {condition}") from e + + def _get_steps(self): + return self.n_steps + + def _get_total_steps(self): + return self.total_steps + + def _get_episodes(self): + return self.n_episodes + + def _get_total_episodes(self): + return self.total_episodes + + def _get_tasks(self): + return self.n_tasks + + def _get_total_tasks(self): + return self.total_tasks + + def _get_episode_return(self): + return sum(self.episode_returns) / len(self.episode_returns) if len(self.episode_returns) > 0 else 0 + + @property + def current_curriculum(self): + 
return self.curriculum_list[self._curriculum_index] + + @property + def requires_step_updates(self): + return any(map(lambda c: c.requires_step_updates, self.curriculum_list)) + + def _sample_distribution(self) -> List[float]: + """ + Return None to indicate that tasks are not drawn from a distribution. + """ + return None + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the current curriculum and re-encode them in this curriculum's task space. + """ + curriculum = self.current_curriculum + tasks = curriculum.sample(k) + + # Recode tasks into environment task space + decoded_tasks = [curriculum.task_space.decode(task) for task in tasks] + recoded_tasks = [self.task_space.encode(task) for task in decoded_tasks] + + self.n_tasks += k + self.total_tasks += k + + # Check if we should move on to the next phase of the curriculum + self.check_stopping_conditions() + return recoded_tasks + + def update_on_episode(self, episode_return, episode_len, episode_task, env_id=None): + self.n_episodes += 1 + self.total_episodes += 1 + self.n_steps += episode_len + self.total_steps += episode_len + self.episode_returns.append(episode_return) + + # Update current curriculum + if self.current_curriculum.requires_episode_updates: + self.current_curriculum.update_on_episode(episode_return, episode_len, episode_task, env_id) + + def update_on_step(self, task, obs, rew, term, trunc, info, env_id=None): + if self.current_curriculum.requires_step_updates: + self.current_curriculum.update_on_step(task, obs, rew, term, trunc, info, env_id) + + def update_on_step_batch(self, step_results, env_id=None): + if self.current_curriculum.requires_step_updates: + self.current_curriculum.update_on_step_batch(step_results, env_id) + + def update_on_demand(self, metrics): + self.current_curriculum.update_on_demand(metrics) + + def update_task_progress(self, task, progress, env_id=None): + self.current_curriculum.update_task_progress(task, progress, env_id) + + def check_stopping_conditions(self): + if self._curriculum_index < 
len(self.stopping_conditions) and self.stopping_conditions[self._curriculum_index](): + self._curriculum_index += 1 + self.n_episodes = 0 + self.n_steps = 0 + self.episode_returns = [] + self.n_tasks = 0 + + def log_metrics(self, writer, step=None, log_full_dist=False): + # super().log_metrics(writer, step, log_full_dist) + writer.add_scalar("curriculum/current_stage", self._curriculum_index, step) + writer.add_scalar("curriculum/steps", self.n_steps, step) + writer.add_scalar("curriculum/episodes", self.n_episodes, step) + writer.add_scalar("curriculum/episode_returns", self._get_episode_return(), step) diff --git a/curricula/simple_box.py b/curricula/simple_box.py new file mode 100644 index 00000000..97ed6e7e --- /dev/null +++ b/curricula/simple_box.py @@ -0,0 +1,63 @@ +import typing +from typing import Any, List, Union + +from gymnasium.spaces import Box + +from syllabus.core import Curriculum + + +class SimpleBoxCurriculum(Curriculum): + """ + Curriculum for Box task spaces that starts with a narrow range around the midpoint and widens it after consecutive successes. + """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = False + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, + *curriculum_args, + steps: int = 5, + success_threshold: float = 0.25, + required_successes: int = 10, + **curriculum_kwargs): + super().__init__(*curriculum_args, **curriculum_kwargs) + assert isinstance(self.task_space.gym_space, Box), "SimpleBoxCurriculum only supports Box task spaces." 
+ + self.success_threshold = success_threshold + self.required_successes = required_successes + + full_range = self.task_space.gym_space.high[1] - self.task_space.gym_space.low[0] + midpoint = self.task_space.gym_space.low[0] + (full_range / 2.0) + self.step_size = (full_range / 2.0) / steps + self.max_range = (midpoint - self.step_size, midpoint + self.step_size) + self.consecutive_successes = 0 + self.max_reached = False + + def update_task_progress(self, task: typing.Any, success_prob: float, env_id: int = None) -> None: + """ + Update the curriculum with a task and its success probability upon + success or failure. + """ + if self.max_reached: + return + + # Check if this task passed success threshold + if success_prob > self.success_threshold: + self.consecutive_successes += 1 + else: + self.consecutive_successes = 0 + + # If we have enough successes in a row, update task + if self.consecutive_successes >= self.required_successes: + new_low = max(self.max_range[0] - self.step_size, self.task_space.gym_space.low[0]) + new_high = min(self.max_range[1] + self.step_size, self.task_space.gym_space.high[1]) + self.max_range = (new_low, new_high) + self.consecutive_successes = 0 + if new_low == self.task_space.gym_space.low[0] and new_high == self.task_space.gym_space.high[1]: + self.max_reached = True + + def sample(self, k: int = 1) -> Union[List, Any]: + """ + Sample k tasks from the curriculum. + """ + return [self.max_range for _ in range(k)] diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/experimental/README b/examples/experimental/README new file mode 100644 index 00000000..1302681f --- /dev/null +++ b/examples/experimental/README @@ -0,0 +1 @@ +The training scripts in this folder are unlikely to run without errors, but may serve as useful references if you want to use Syllabus in a new RL library or environment. 
\ No newline at end of file diff --git a/examples/experimental/cleanrl_cartpole.py b/examples/experimental/cleanrl_cartpole.py new file mode 100644 index 00000000..ab71e478 --- /dev/null +++ b/examples/experimental/cleanrl_cartpole.py @@ -0,0 +1,326 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +# Syllabus imports +from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum +from syllabus.curricula import SimpleBoxCurriculum +from syllabus.examples.task_wrappers import CartPoleTaskWrapper +from syllabus.task_space import TaskSpace +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="Syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + 
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="CartPole-v1", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=500000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=2.5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=4, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=128, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=4, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + 
parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name, task_queue, update_queue): + def thunk(): + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + # env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + env = CartPoleTaskWrapper(env) + env = MultiProcessingSyncWrapper(env, task_queue, update_queue, + task_space=TaskSpace(gym.spaces.Box(-0.3, 0.3, shape=(2,))), + update_on_step=False) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.critic = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 1), std=1.0), + ) + self.actor = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01), + ) + + def get_value(self, x): + return self.critic(x) + + def get_action_and_value(self, x, action=None): + logits = self.actor(x) + probs = Categorical(logits=logits) + if action is None: + action = 
probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(x) + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # curriculum setup + curriculum = SimpleBoxCurriculum(TaskSpace(gym.spaces.Box(-0.3, 0.3, shape=(2,)))) + curriculum, task_queue, update_queue = make_multiprocessing_curriculum(curriculum) + + # env setup + envs = gym.vector.AsyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, task_queue, update_queue) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, 
args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs, _ = envs.reset() + next_obs = torch.Tensor(next_obs).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, term, trunc, infos = envs.step(action.cpu().numpy()) + done = np.logical_or(term, trunc) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + + if "final_info" in infos: + for info in infos["final_info"]: + if info and "episode" in info: + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + curriculum.log_metrics(writer, step=global_step) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * 
nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - 
b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + envs.close() + writer.close() diff --git a/examples/experimental/cleanrl_minigrid_plr.py b/examples/experimental/cleanrl_minigrid_plr.py new file mode 100644 index 00000000..83454aa9 --- /dev/null +++ b/examples/experimental/cleanrl_minigrid_plr.py @@ -0,0 +1,353 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from syllabus.core import (MultiProcessingSyncWrapper, TaskWrapper, + 
make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples.models import MinigridAgent +from syllabus.examples.task_wrappers import MinigridTaskWrapper +from torch.utils.tensorboard import SummaryWriter + +from .vecenv import VecNormalize + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="MiniGrid-MultiRoom-N4-Random-v0", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=100000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=7e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + 
help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=256, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gamma", type=float, default=0.999, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=8, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name, task_queue, update_queue): + def thunk(): + env = gym.make(env_id) + 
#env = gym.wrappers.FlattenObservation(env) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = MinigridTaskWrapper(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + # env = MultiProcessingSyncWrapper( + # env, + # task_queue, + # update_queue, + # update_on_step=False, + # default_task=0, + # task_space=env.task_space, + # ) + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + # sample_env = gym.make(args.env_id) + # sample_env = MinigridTaskWrapper(sample_env) + # curriculum = PrioritizedLevelReplay( + # sample_env.task_space, task_sampler_kwargs_dict={"strategy":"one_step_td_error", "rho":0.01, "nu":0}, num_processes=args.num_envs, gamma=args.gamma, gae_lambda=args.gae_lambda, + # ) + # curriculum, task_queue, update_queue = make_multiprocessing_curriculum(curriculum) + # del sample_env + task_queue = update_queue = None + # env setup + envs = gym.vector.AsyncVectorEnv( + [ + make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, task_queue, update_queue) + for i in range(args.num_envs) + ] + ) + envs = VecNormalize(envs, ob=False) 
+ + assert isinstance( + envs.single_action_space, gym.spaces.Discrete + ), "only discrete action space is supported" + + agent = MinigridAgent(envs.single_observation_space.shape, envs.single_action_space.n).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros( + (args.num_steps, args.num_envs) + envs.single_observation_space.shape + ).to(device) + actions = torch.zeros( + (args.num_steps, args.num_envs) + envs.single_action_space.shape + ).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value, full_log_probs = agent.get_action_and_value( + next_obs, full_log_probs=True + ) + values[step] = value.flatten() + + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. 
+ next_obs, reward, done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor( + done + ).to(device) + + for item in info: + if "episode" in item.keys(): + print( + f"global_step={global_step}, episodic_return={item['episode']['r']}" + ) + writer.add_scalar( + "charts/episodic_return", item["episode"]["r"], global_step + ) + writer.add_scalar( + "charts/episodic_length", item["episode"]["l"], global_step + ) + break + + # update = { + # "update_type": "on_demand", + # "metrics": { + # "action_log_dist": full_log_probs, + # "value": value, + # "next_value": agent.get_value(next_obs) + # if step == args.num_steps - 1 + # else None, + # "rew": reward, + # "masks": torch.Tensor(1 - done), + # "tasks": envs.get_attr("task"), + # }, + # } + # curriculum.update_curriculum(update) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = ( + rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + ) + advantages[t] = lastgaelam = ( + delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + ) + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + 
for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value( + b_obs[mb_inds], b_actions.long()[mb_inds] + ) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [ + ((ratio - 1.0).abs() > args.clip_coef).float().mean().item() + ] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / ( + mb_advantages.std() + 1e-8 + ) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp( + ratio, 1 - args.clip_coef, 1 + args.clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar( + "charts/learning_rate", 
optimizer.param_groups[0]["lr"], global_step + ) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar( + "charts/SPS", int(global_step / (time.time() - start_time)), global_step + ) + + envs.close() + writer.close() diff --git a/examples/experimental/cleanrl_minihack_plr.py b/examples/experimental/cleanrl_minihack_plr.py new file mode 100644 index 00000000..8f88ad5d --- /dev/null +++ b/examples/experimental/cleanrl_minihack_plr.py @@ -0,0 +1,358 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy + +import argparse +import os +import random +import time +from distutils.util import strtobool + +import gym +#import minigrid +import minihack +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from gym.envs.registration import register +from syllabus.core import (MultiProcessingSyncWrapper, TaskWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples import MinihackTaskWrapper +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + 
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="whether to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="MiniHack-River-v0", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=500000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=2.5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=4, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=128, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + 
parser.add_argument("--num-minibatches", type=int, default=4, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=4, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name, task_queue, update_queue): + def thunk(): + env = gym.make(env_id, observation_keys=("pixel", "glyphs")) + env = gym.wrappers.FlattenObservation(env) + env = gym.wrappers.RecordEpisodeStatistics(env) + env = MinihackTaskWrapper(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env = MultiProcessingSyncWrapper(env, + task_queue, + update_queue, + update_on_step=False, + default_task=0, + task_space=gym.spaces.Discrete(1)) + #env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + 
torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.critic = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 1), std=1.0), + ) + self.actor = nn.Sequential( + layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)), + nn.Tanh(), + layer_init(nn.Linear(64, 64)), + nn.Tanh(), + layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01), + ) + + def get_value(self, x): + return self.critic(x) + + def get_action_and_value(self, x, action=None, full_log_probs=False): + logits = self.actor(x) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + if full_log_probs: + log_probs = torch.log(probs.probs) + return action, probs.log_prob(action), probs.entropy(), self.critic(x), log_probs + return action, probs.log_prob(action), probs.entropy(), self.critic(x) + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + sample_env = gym.make(args.env_id) + sample_env = MinihackTaskWrapper(sample_env) + + 
curriculum, task_queue, update_queue = make_multiprocessing_curriculum(CentralizedPrioritizedLevelReplay, + sample_env.task_space, + {"strategy": "one_step_td_error", + "rho": 0.01, + "nu": 0}, + action_space=sample_env.action_space, + num_steps=args.num_steps, + num_processes=args.num_envs, + gamma=args.gamma, + gae_lambda=args.gae_lambda, + random_start_tasks=0) + del sample_env + + # env setup + envs = gym.vector.AsyncVectorEnv( + [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name, task_queue, update_queue) for i in range(args.num_envs)] + ) + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + num_updates = args.total_timesteps // args.batch_size + + for update in range(1, num_updates + 1): + # Annealing the rate if instructed to do so. 
+ if args.anneal_lr: + frac = 1.0 - (update - 1.0) / num_updates + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += 1 * args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value, full_log_probs = agent.get_action_and_value(next_obs, full_log_probs=True) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, reward, done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) + + for item in info: + if "episode" in item.keys(): + print(f"global_step={global_step}, episodic_return={item['episode']['r']}") + writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step) + break + + update = { + "update_type": "on_demand", + "metrics": { + "action_log_dist": full_log_probs, + "value": value, + "next_value": agent.get_value(next_obs) if step == args.num_steps - 1 else None, + "rew": reward, + "masks": torch.Tensor(1 - done), + "tasks": envs.get_attr("current_task"), + } + } + curriculum.update_curriculum(update) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + 
returns = advantages + values + + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + 
optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None: + if approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() \ No newline at end of file diff --git a/examples/experimental/cleanrl_pettingzoo_pistonball_plr.py b/examples/experimental/cleanrl_pettingzoo_pistonball_plr.py new file mode 100644 index 00000000..2b416e21 --- /dev/null +++ b/examples/experimental/cleanrl_pettingzoo_pistonball_plr.py @@ -0,0 +1,334 @@ +"""Basic code which shows what it's like to run PPO on the Pistonball env using the parallel API, this code is inspired by CleanRL. + +This code is exceedingly basic, with no logging or weights saving. +The intention was for users to have a (relatively clean) ~200 line file to refer to when they want to design their own learning algorithm. 
+ +Author: Jet (https://github.com/jjshoots) +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from pettingzoo.butterfly import pistonball_v6 +from supersuit import color_reduction_v0, frame_stack_v1, resize_v1 +from syllabus.core import (PettingZooMultiProcessingSyncWrapper, TaskWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples import PistonballTaskWrapper +from torch.distributions.categorical import Categorical + + +class Agent(nn.Module): + def __init__(self, num_actions): + super().__init__() + + self.network = nn.Sequential( + self._layer_init(nn.Conv2d(4, 32, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + self._layer_init(nn.Conv2d(32, 64, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + self._layer_init(nn.Conv2d(64, 128, 3, padding=1)), + nn.MaxPool2d(2), + nn.ReLU(), + nn.Flatten(), + self._layer_init(nn.Linear(128 * 8 * 8, 512)), + nn.ReLU(), + ) + self.actor = self._layer_init(nn.Linear(512, num_actions), std=0.01) + self.critic = self._layer_init(nn.Linear(512, 1)) + + def _layer_init(self, layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None, full_log_probs=False): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + if full_log_probs: + log_probs = torch.log(probs.probs) + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden), log_probs + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +def batchify_obs(obs, device): + """Converts PZ style observations to batch of torch arrays.""" + # convert to list of np arrays + obs = np.stack([obs[a] for a in obs], 
axis=0) + # transpose to be (batch, channel, height, width) + obs = obs.transpose(0, -1, 1, 2) + # convert to torch + obs = torch.tensor(obs).to(device) + + return obs + + +def batchify(x, device): + """Converts PZ style returns to batch of torch arrays.""" + # convert to list of np arrays + x = np.stack([x[a] for a in x], axis=0) + # convert to torch + x = torch.tensor(x).to(device) + + return x + + +def unbatchify(x, env): + """Converts np array to PZ style arguments.""" + x = x.cpu().numpy() + x = {a: x[i] for i, a in enumerate(env.possible_agents)} + + return x + + +if __name__ == "__main__": + """ALGO PARAMS""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ent_coef = 0.1 + vf_coef = 0.1 + clip_coef = 0.1 + gamma = 0.99 + batch_size = 32 + stack_size = 4 + frame_size = (64, 64) + max_cycles = 125 + total_episodes = 100 + + # PLR settings + num_steps = 128 + + """ CURRICULUM SETUP """ + sample_env = pistonball_v6.parallel_env( + continuous=False, max_cycles=max_cycles + ) + sample_env = PistonballTaskWrapper(sample_env) + + num_agents = len(sample_env.possible_agents) + action_space = sample_env.action_spaces[sample_env.possible_agents[0]] + num_actions = action_space.n + observation_size = sample_env.observation_space(sample_env.possible_agents[0]).shape + + curriculum, task_queue, update_queue = make_multiprocessing_curriculum(CentralizedPrioritizedLevelReplay, + sample_env.task_space, + task_sampler_kwargs_dict={"strategy": "one_step_td_error", + "rho": 0.01, + "nu": 0}, + action_space=action_space, + num_steps=num_steps, + num_processes=num_agents, + gamma=gamma, + gae_lambda=0.95, + random_start_tasks=0) + del sample_env + + """ ENV SETUP """ + env = pistonball_v6.parallel_env( + continuous=False, max_cycles=max_cycles + ) + env = PistonballTaskWrapper(env) + env = PettingZooMultiProcessingSyncWrapper(env, + task_queue, + update_queue, + update_on_step=False, + default_task=0, + task_space=env.task_space) + task_space = 
env.task_space + env = color_reduction_v0(env) + env = resize_v1(env, frame_size[0], frame_size[1]) + env = frame_stack_v1(env, stack_size=stack_size) + num_agents = len(env.possible_agents) + action_space = env.action_spaces[env.possible_agents[0]] + num_actions = action_space.n + observation_size = env.observation_space(env.possible_agents[0]).shape + + """ LEARNER SETUP """ + agent = Agent(num_actions=num_actions).to(device) + optimizer = optim.Adam(agent.parameters(), lr=0.001, eps=1e-5) + + """ ALGO LOGIC: EPISODE STORAGE""" + end_step = 0 + total_episodic_return = 0 + rb_obs = torch.zeros((max_cycles, num_agents, stack_size, *frame_size)).to(device) + rb_actions = torch.zeros((max_cycles, num_agents)).to(device) + rb_logprobs = torch.zeros((max_cycles, num_agents)).to(device) + rb_rewards = torch.zeros((max_cycles, num_agents)).to(device) + rb_dones = torch.zeros((max_cycles, num_agents)).to(device) + rb_values = torch.zeros((max_cycles, num_agents)).to(device) + + """ TRAINING LOGIC """ + # train for n number of episodes + global_cycles = 0 + for episode in range(total_episodes): + # collect an episode + with torch.no_grad(): + # collect observations and convert to batch of torch tensors + next_obs = env.reset() + # reset the episodic return + total_episodic_return = 0 + + # each episode has num_steps + for step in range(0, max_cycles): + global_cycles += 1 + # rollover the observation + obs = batchify_obs(next_obs, device) + + # get action from the agent + actions, logprobs, _, values, full_log_probs = agent.get_action_and_value(obs, full_log_probs=True) + + # execute the environment and log data + next_obs, rewards, dones, infos = env.step( + unbatchify(actions, env) + ) + + # add to episode storage + rb_obs[step] = obs + rb_rewards[step] = batchify(rewards, device) + rb_dones[step] = batchify(dones, device) + rb_actions[step] = actions + rb_logprobs[step] = logprobs + rb_values[step] = values.flatten() + + # compute episodic return + total_episodic_return 
+= rb_rewards[step].cpu().numpy() + + # Update curriculum + if global_cycles % num_steps == 0: + update = { + "update_type": "on_demand", + "metrics": { + "action_log_dist": full_log_probs, + "value": values, + "next_value": agent.get_value(next_obs) if step == num_steps - 1 else None, + "rew": rb_rewards[step], + "masks": torch.Tensor(1 - np.array(list(dones.values()))), + "tasks": [env.unwrapped.task], + } + } + curriculum.update_curriculum(update) + + # if we reach the end of the episode + if any([dones[a] for a in dones]): + end_step = step + break + + + # bootstrap value if not done + with torch.no_grad(): + rb_advantages = torch.zeros_like(rb_rewards).to(device) + for t in reversed(range(end_step)): + delta = ( + rb_rewards[t] + + gamma * rb_values[t + 1] * rb_dones[t + 1] + - rb_values[t] + ) + rb_advantages[t] = delta + gamma * gamma * rb_advantages[t + 1] + rb_returns = rb_advantages + rb_values + + # convert our episodes to batch of individual transitions + b_obs = torch.flatten(rb_obs[:end_step], start_dim=0, end_dim=1) + b_logprobs = torch.flatten(rb_logprobs[:end_step], start_dim=0, end_dim=1) + b_actions = torch.flatten(rb_actions[:end_step], start_dim=0, end_dim=1) + b_returns = torch.flatten(rb_returns[:end_step], start_dim=0, end_dim=1) + b_values = torch.flatten(rb_values[:end_step], start_dim=0, end_dim=1) + b_advantages = torch.flatten(rb_advantages[:end_step], start_dim=0, end_dim=1) + + # Optimizing the policy and value network + b_index = np.arange(len(b_obs)) + clip_fracs = [] + for repeat in range(3): + # shuffle the indices we use to access the data + np.random.shuffle(b_index) + for start in range(0, len(b_obs), batch_size): + # select the indices we want to train on + end = start + batch_size + batch_index = b_index[start:end] + + _, newlogprob, entropy, value = agent.get_action_and_value( + b_obs[batch_index], b_actions.long()[batch_index] + ) + logratio = newlogprob - b_logprobs[batch_index] + ratio = logratio.exp() + + with 
torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clip_fracs += [ + ((ratio - 1.0).abs() > clip_coef).float().mean().item() + ] + + # normalize advantaegs + advantages = b_advantages[batch_index] + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) + + # Policy loss + pg_loss1 = -b_advantages[batch_index] * ratio + pg_loss2 = -b_advantages[batch_index] * torch.clamp( + ratio, 1 - clip_coef, 1 + clip_coef + ) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + value = value.flatten() + v_loss_unclipped = (value - b_returns[batch_index]) ** 2 + v_clipped = b_values[batch_index] + torch.clamp( + value - b_values[batch_index], + -clip_coef, + clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[batch_index]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + + entropy_loss = entropy.mean() + loss = pg_loss - ent_coef * entropy_loss + v_loss * vf_coef + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + + print(f"Training episode {episode}") + print(f"Episodic Return: {np.mean(total_episodic_return)}") + print(f"Episode Length: {end_step}") + print("") + print(f"Value Loss: {v_loss.item()}") + print(f"Policy Loss: {pg_loss.item()}") + print(f"Old Approx KL: {old_approx_kl.item()}") + print(f"Approx KL: {approx_kl.item()}") + print(f"Clip Fraction: {np.mean(clip_fracs)}") + print(f"Explained Variance: {explained_var.item()}") + print("\n-------------------------------------------\n") + + """ RENDER THE POLICY """ + env = pistonball_v6.parallel_env(continuous=False) + env = color_reduction_v0(env) + env = resize_v0(env, 64, 64) + env = frame_stack_v1(env, stack_size=4) + + 
agent.eval() + + with torch.no_grad(): + # render 5 episodes out + for episode in range(5): + obs = batchify_obs(env.reset(), device) + dones = [False] + while not any(dones): + actions, logprobs, _, values = agent.get_action_and_value(obs) + obs, rewards, dones, infos = env.step(unbatchify(actions, env)) + obs = batchify_obs(obs, device) + dones = [dones[a] for a in dones] diff --git a/examples/experimental/dormant_neurons.py b/examples/experimental/dormant_neurons.py new file mode 100644 index 00000000..ca2eb21c --- /dev/null +++ b/examples/experimental/dormant_neurons.py @@ -0,0 +1,764 @@ +"""This class implements recycling of dead neurons.""" +import functools +import logging + +import flax +import jax +import jax.numpy as jnp +import optax +from flax import linen as flax_nn +from jax import random + + +def leastk_mask(scores, ones_fraction): + """Given a tensor of scores creates a binary mask. + + Args: + scores: top-scores are kept + ones_fraction: float, of the generated mask. + + Returns: + array, same shape and type as scores or None. + """ + if ones_fraction is None or ones_fraction == 0: + return jnp.zeros_like(scores) + # This is to ensure indices with smallest values are selected. 
+ scores = -scores + + n_ones = jnp.round(jnp.size(scores) * ones_fraction) + k = jnp.maximum(1, n_ones).astype(int) + flat_scores = jnp.reshape(scores, -1) + threshold = jax.lax.sort(flat_scores)[-k] + + mask = (flat_scores >= threshold).astype(flat_scores.dtype) + return mask.reshape(scores.shape) + + +def reset_momentum(momentum, mask): + new_momentum = momentum if mask is None else momentum * (1.0 - mask) + return new_momentum + + +def weight_reinit_zero(param, mask): + if mask is None: + return param + else: + new_param = jnp.zeros_like(param) + param = jnp.where(mask == 1, new_param, param) + return param + + +def weight_reinit_random(param, mask, key, weight_scaling=False, scale=1.0, weights_type='incoming'): + """Randomly reinit recycled weights and may scale its norm. + + If scaling applied, the norm of recycled weights equals + the average norm of non recycled weights per neuron multiplied by a scalar. + + Args: + param: current param + mask: incoming/outgoing mask for recycled weights + key: random key to generate new random weights + weight_scaling: if true scale recycled weights with the norm of non recycled + scale: scale to multiply the new weights norm. + weights_type: incoming or outgoing weights + + Returns: + params: new params after weight recycle. 
+ """ + if mask is None or key is None: + return param + + new_param = flax_nn.initializers.xavier_uniform()(key, shape=param.shape) + + if weight_scaling: + axes = list(range(param.ndim)) + if weights_type == 'outgoing': + del axes[-2] + else: + del axes[-1] + + neuron_mask = jnp.mean(mask, axis=axes) + + non_dead_count = neuron_mask.shape[0] - jnp.count_nonzero(neuron_mask) + norm_per_neuron = _get_norm_per_neuron(param, axes) + non_recycled_norm = jnp.sum(norm_per_neuron * + (1 - neuron_mask)) / non_dead_count + non_recycled_norm = non_recycled_norm * scale + + normalized_new_param = _weight_normalization_per_neuron_norm( + new_param, axes) + new_param = normalized_new_param * non_recycled_norm + + param = jnp.where(mask == 1, new_param, param) + return param + + +def _weight_normalization_per_neuron_norm(param, axes): + norm_per_neuron = _get_norm_per_neuron(param, axes) + norm_per_neuron = jnp.expand_dims(norm_per_neuron, axis=axes) + normalized_param = param / norm_per_neuron + return normalized_param + + +def _get_norm_per_neuron(param, axes): + return jnp.sqrt(jnp.sum(jnp.power(param, 2), axis=axes)) + + +class BaseRecycler(): + """Base class for weight update methods. + + Attributes: + all_layers_names: list of layer names in a model. + recycle_type: neuron, layer based. + dead_neurons_threshold: below this threshold a neuron is considered dead. + reset_layers: list of layer names to be recycled. + reset_start_layer_idx: index of the layer from which we start recycling. + reset_period: int represents the period of weight update. + reset_start_step: start recycle from start step + reset_end_step: end recycle from end step + logging_period: the period of statistics logging e.g., dead neurons. + prev_neuron_score: score at last reset step or log step in case of no reset. + sub_mean_score: if True the average activation will be subtracted + for each neuron when we calculate the score. 
+ """ + + def __init__( + self, + all_layers_names, + dead_neurons_threshold=0.0, + reset_start_layer_idx=0, + reset_period=200_000, + reset_start_step=0, + reset_end_step=100_000_000, + logging_period=20_000, + sub_mean_score=False, + ): + self.all_layers_names = all_layers_names + self.dead_neurons_threshold = dead_neurons_threshold + self.reset_layers = all_layers_names[reset_start_layer_idx:] + self.reset_period = reset_period + self.reset_start_step = reset_start_step + self.reset_end_step = reset_end_step + self.logging_period = logging_period + self.prev_neuron_score = None + self.sub_mean_score = sub_mean_score + + def update_reset_layers(self, reset_start_layer_idx): + self.reset_layers = self.all_layers_names[reset_start_layer_idx:] + + def is_update_iter(self, step): + return step > 0 and (step % self.reset_period == 0) + + def update_weights(self, intermediates, params, key, opt_state): + raise NotImplementedError + + def maybe_update_weights(self, update_step, intermediates, params, key, opt_state): + self._last_update_step = update_step + if self.is_reset(update_step): + new_params, new_opt_state = self.update_weights(intermediates, params, key, opt_state) + else: + new_params, new_opt_state = params, opt_state + return new_params, new_opt_state + + def is_reset(self, update_step): + del update_step + return False + + def is_intermediated_required(self, update_step): + return self.is_logging_step(update_step) + + def is_logging_step(self, step): + return step % self.logging_period == 0 + + def maybe_log_deadneurons(self, update_step, intermediates): + is_logging = self.is_logging_step(update_step) + if is_logging: + return self.log_dead_neurons_count(intermediates) + else: + return None + + def intersected_dead_neurons_with_last_reset(self, intermediates, update_step): + if self.is_logging_step(update_step): + log_dict = self.log_intersected_dead_neurons(intermediates) + return log_dict + else: + return None + + def log_intersected_dead_neurons(self, 
intermediates): + """Track intersected dead neurons with last logging/reset step. + + Args: + intermediates: current intermediates + + Returns: + log_dict: dict contains the percentage of intersection + """ + score_tree = jax.tree_map(self.estimate_neuron_score, intermediates) + neuron_score_dict = flax.traverse_util.flatten_dict(score_tree, sep='/') + + if self.prev_neuron_score is None: + self.prev_neuron_score = neuron_score_dict + log_dict = None + else: + log_dict = {} + for prev_k_score, current_k_score in zip(self.prev_neuron_score.items(), + neuron_score_dict.items()): + _, prev_score = prev_k_score + k, score = current_k_score + prev_score, score = prev_score[0], score[0] + prev_mask = prev_score <= self.dead_neurons_threshold + # we count the dead neurons which remains dead in the current step. + intersected_mask = (prev_mask) & (score <= self.dead_neurons_threshold) + prev_dead_count = jnp.count_nonzero(prev_mask) + intersected_count = jnp.count_nonzero(intersected_mask) + + percent = (float(intersected_count) / + prev_dead_count) if prev_dead_count else 0.0 + log_dict[f'dead_intersected_percent/{k[:-9]}'] = float(percent) * 100. + + # track average score of recycled neurons from last step and + # the average of non dead in the current step. + nondead_mask = score > self.dead_neurons_threshold + log_dict[f'mean_score_recycled/{k[:-9]}'] = float( + jnp.mean(score[prev_mask])) + log_dict[f'mean_score_nondead/{k[:-9]}'] = float( + jnp.mean(score[nondead_mask])) + + self.prev_neuron_score = neuron_score_dict + return log_dict + + def log_dead_neurons_count(self, intermediates): + """log dead neurons in each layer. + + For conv layer we also log dead elements in the spatial dimension. + + Args: + intermediates: intermidate activation in each layer. + + Returns: + log_dict_elements_per_neuron + log_dict_neurons + """ + + def log_dict(score, score_type): + total_neurons, total_deadneurons = 0., 0. 
+ score_dict = flax.traverse_util.flatten_dict(score, sep='/') + + log_dict = {} + for k, m in score_dict.items(): + if 'final_layer' in k: + continue + m = m[0] + layer_size = float(jnp.size(m)) + deadneurons_count = jnp.count_nonzero(m <= self.dead_neurons_threshold) + total_neurons += layer_size + total_deadneurons += deadneurons_count + log_dict[f'dead_{score_type}_percentage/{k[:-9]}'] = ( + float(deadneurons_count) / layer_size) * 100. + log_dict[f'dead_{score_type}_count/{k[:-9]}'] = float(deadneurons_count) + log_dict[f'{score_type}/total'] = total_neurons + log_dict[f'{score_type}/deadcount'] = float(total_deadneurons) + log_dict[f'dead_{score_type}_percentage'] = (float(total_deadneurons) / + total_neurons) * 100. + return log_dict + + neuron_score = jax.tree_map(self.estimate_neuron_score, intermediates) + log_dict_neurons = log_dict(neuron_score, 'feature') + + return log_dict_neurons + + def estimate_neuron_score(self, activation, is_cbp=False): + """Calculates neuron score based on absolute value of activation. + + The score of feature map is the normalized average score over + the spatial dimension. + + Args: + activation: intermediate activation of each layer + is_cbp: if true, subtracts the mean and skips normalization. + + Returns: + element_score: score of each element in feature map in the spatial dim. + neuron_score: score of feature map + """ + reduce_axes = list(range(activation.ndim - 1)) + if self.sub_mean_score or is_cbp: + activation = activation - jnp.mean(activation, axis=reduce_axes) + + score = jnp.mean(jnp.abs(activation), axis=reduce_axes) + if not is_cbp: + # Normalize so that all scores sum to one. + score /= jnp.mean(score) + 1e-9 + + return score + + +class NeuronRecycler(BaseRecycler): + """Recycle the weights connected to dead neurons. + + In convolutional neural networks, we consider a feature map as neuron. + + Attributes: + next_layers: dict key a current layer name, value next layer name. 
+ init_method_outgoing: method to init outgoing weights (random, zero). + weight_scaling: if true, scale reinit weights. + incoming_scale: scalar for incoming weights. + outgoing_scale: scalar for outgoing weights. + """ + + def __init__(self, + all_layers_names, + init_method_outgoing='zero', + weight_scaling=False, + incoming_scale=1.0, + outgoing_scale=1.0, + network='nature', + **kwargs): + super(NeuronRecycler, self).__init__(all_layers_names, **kwargs) + self.init_method_outgoing = init_method_outgoing + self.weight_scaling = weight_scaling + self.incoming_scale = incoming_scale + self.outgoing_scale = outgoing_scale + # prepare a dict that has pointer to next layer give a layer name + # this is needed because neuron recycle reinitalizes both sides + # (incoming and outgoing weights) of a neuron and needs a point to the + # outgoing weights. + self.next_layers = {} + for current_layer, next_layer in zip(all_layers_names[:-1], + all_layers_names[1:]): + self.next_layers[current_layer] = next_layer + + # we don't recycle the neurons in the output layer. 
+ self.reset_layers = self.reset_layers[:-1] + + # if network is resnet, recycle only the incoming/outgoing of the first conv + # layer in each block and final dense layer + if network == 'resnet': + self.reset_layers = [] + for layer in self.all_layers_names: + if 'Conv_1' in layer or 'Conv_3' in layer or 'Dense' in layer: + self.reset_layers.append(layer) + + def intersected_dead_neurons_with_last_reset(self, intermediates, + update_step): + if self.is_reset(update_step): + log_dict = self.log_intersected_dead_neurons(intermediates) + return log_dict + else: + return None + + def is_reset(self, update_step): + within_reset_interval = ( + update_step >= self.reset_start_step and + update_step < self.reset_end_step) + return self.is_update_iter(update_step) and within_reset_interval + + def is_intermediated_required(self, update_step): + is_logging = self.is_logging_step(update_step) + is_update_iter = self.is_update_iter(update_step) + return is_logging or is_update_iter + + def update_reset_layers(self, reset_start_layer_idx): + self.reset_layers = self.all_layers_names[reset_start_layer_idx:] + self.reset_layers = self.reset_layers[:-1] + + def update_weights(self, intermediates, params, key, opt_state): + new_param, opt_state = self.recycle_dead_neurons( + intermediates, params, key, opt_state + ) + return new_param, opt_state + + def recycle_dead_neurons(self, intermedieates, params, key, opt_state): + """Recycle dead neurons by reinitalizie incoming and outgoing connections. + + Incoming connections are randomly initalized and outgoing connections + are zero initalized. + A featuremap is considered dead when its score is below or equal + dead neuron threshold. + Args: + intermedieates: pytree contains the activations over a batch. + params: current weights of the model. + key: used to generate random keys. + opt_state: state of optimizer. + + Returns: + new model params after recycling dead neurons. 
+ opt_state: new state for the optimizer + + Raises: raise error if init_method_outgoing is not one of the following + (random, zero). + """ + activations_score_dict = flax.traverse_util.flatten_dict( + intermedieates, sep='/' + ) + param_dict = flax.traverse_util.flatten_dict(params, sep='/') + + # create incoming and outgoing masks and reset bias of dead neurons. + ( + incoming_mask_dict, + outgoing_mask_dict, + incoming_random_keys_dict, + outgoing_random_keys_dict, + param_dict, + ) = self.create_masks(param_dict, activations_score_dict, key) + + params = flax.core.freeze( + flax.traverse_util.unflatten_dict(param_dict, sep='/')) + incoming_random_keys = flax.core.freeze( + flax.traverse_util.unflatten_dict(incoming_random_keys_dict, sep='/')) + if self.init_method_outgoing == 'random': + outgoing_random_keys = flax.core.freeze( + flax.traverse_util.unflatten_dict(outgoing_random_keys_dict, sep='/')) + # reset incoming weights + incoming_mask = flax.core.freeze( + flax.traverse_util.unflatten_dict(incoming_mask_dict, sep='/')) + reinit_fn = functools.partial( + weight_reinit_random, + weight_scaling=self.weight_scaling, + scale=self.incoming_scale, + weights_type='incoming') + weight_random_reset_fn = jax.jit(functools.partial(jax.tree_map, reinit_fn)) + params = weight_random_reset_fn(params, incoming_mask, incoming_random_keys) + + # reset outgoing weights + outgoing_mask = flax.core.freeze( + flax.traverse_util.unflatten_dict(outgoing_mask_dict, sep='/')) + if self.init_method_outgoing == 'random': + reinit_fn = functools.partial( + weight_reinit_random, + weight_scaling=self.weight_scaling, + scale=self.outgoing_scale, + weights_type='outgoing') + weight_random_reset_fn = jax.jit( + functools.partial(jax.tree_map, reinit_fn)) + params = weight_random_reset_fn(params, outgoing_mask, + outgoing_random_keys) + elif self.init_method_outgoing == 'zero': + weight_zero_reset_fn = jax.jit( + functools.partial(jax.tree_map, weight_reinit_zero)) + params = 
weight_zero_reset_fn(params, outgoing_mask) + else: + raise ValueError(f'Invalid init method: {self.init_method_outgoing}') + # reset mu, nu of adam optimizer for recycled weights. + reset_momentum_fn = jax.jit(functools.partial(jax.tree_map, reset_momentum)) + new_mu = reset_momentum_fn(opt_state[0][1], incoming_mask) + new_mu = reset_momentum_fn(new_mu, outgoing_mask) + new_nu = reset_momentum_fn(opt_state[0][2], incoming_mask) + new_nu = reset_momentum_fn(new_nu, outgoing_mask) + opt_state_list = list(opt_state) + opt_state_list[0] = optax.ScaleByAdamState( + opt_state[0].count, mu=new_mu, nu=new_nu) + opt_state = tuple(opt_state_list) + return params, opt_state + + def _score2mask(self, activation, param, next_param, key): + del key, param, next_param + score = self.estimate_neuron_score(activation) + return score <= self.dead_neurons_threshold + + def create_masks(self, param_dict, activations_dict, key): + """create the masks for recycled weights based on neurons scores. + + Args: + param_dict: dict of model params. + activations_dict: dict of the neuron score of each layer. + key: used seed for random weights. + + Returns: + incoming_mask_dict + outgoing_mask_dict + ingoing_random_keys_dict + outgoing_random_keys_dict + param_dict + """ + incoming_mask_dict = { + k: jnp.zeros_like(p) if p.ndim != 1 else None + for k, p in param_dict.items() + } + outgoing_mask_dict = { + k: jnp.zeros_like(p) if p.ndim != 1 else None + for k, p in param_dict.items() + } + ingoing_random_keys_dict = {k: None for k in param_dict} + outgoing_random_keys_dict = { + k: None for k in param_dict + } if self.init_method_outgoing == 'random' else {} + + # prepare mask of incoming and outgoing recycled connections + for k in self.reset_layers: + param_key = 'params/' + k + '/kernel' + param = param_dict[param_key] + # This won't work for DRQ, since returned keys can be a list. + # We don't support that at the moment. 
+ next_key = self.next_layers[k] + if isinstance(next_key, list): + next_key = next_key[0] + next_param = param_dict['params/' + next_key + '/kernel'] + activation = activations_dict[k + '_act/__call__'][0] + # TODO(evcu) Maybe use per_layer random keys here. + neuron_mask = self._score2mask(activation, param, next_param, key) + + # the for loop handles the case where a layer has multiple next layers + # like the case in DrQ where the output layer has multihead. + next_keys = ( + self.next_layers[k] + if isinstance(self.next_layers[k], list) else [self.next_layers[k]]) + for next_k in next_keys: + next_param_key = 'params/' + next_k + '/kernel' + next_param = param_dict[next_param_key] + incoming_mask, outgoing_mask = self.create_mask_helper( + neuron_mask, param, next_param) + incoming_mask_dict[param_key] = incoming_mask + outgoing_mask_dict[next_param_key] = outgoing_mask + key, subkey = random.split(key) + ingoing_random_keys_dict[param_key] = subkey + if self.init_method_outgoing == 'random': + key, subkey = random.split(key) + outgoing_random_keys_dict[next_param_key] = subkey + + # reset bias + bias_key = 'params/' + k + '/bias' + new_bias = jnp.zeros_like(param_dict[bias_key]) + param_dict[bias_key] = jnp.where(neuron_mask, new_bias, + param_dict[bias_key]) + + return (incoming_mask_dict, outgoing_mask_dict, ingoing_random_keys_dict, + outgoing_random_keys_dict, param_dict) + + def create_mask_helper(self, neuron_mask, current_param, next_param): + """generate incoming and outgoing weight mask given dead neurons mask. + + Args: + neuron_mask: mask of size equals the width of a layer. + current_param: incoming weights of a layer. + next_param: outgoing weights of a layer. + + Returns: + incoming_mask + outgoing_mask + """ + + def mask_creator(expansion_axis, expansion_axes, param, neuron_mask): + """create a mask of weight matrix given 1D vector of neurons mask. + + Args: + expansion_axis: List contains 1 axis. 
The dimension to expand the mask + for dense layers (weight shape 2D). + expansion_axes: List conrtains 3 axes. The dimensions to expand the + score for convolutional layers (weight shape 4D). + param: weight. + neuron_mask: 1D mask that represents dead neurons(features). + + Returns: + mask: mask of weight. + """ + if param.ndim == 2: + axes = expansion_axis + # flatten layer + # The size of neuron_mask is the same as the width of last conv layer. + # This conv layer will be flatten and connected to dense layer. + # we repeat each value of a feature map to cover the spatial dimension. + if axes[0] == 1 and (param.shape[0] > neuron_mask.shape[0]): + num_repeatition = int(param.shape[0] / neuron_mask.shape[0]) + neuron_mask = jnp.repeat(neuron_mask, num_repeatition, axis=0) + elif param.ndim == 4: + axes = expansion_axes + mask = jnp.expand_dims(neuron_mask, axis=tuple(axes)) + for i in range(len(axes)): + mask = jnp.repeat(mask, param.shape[axes[i]], axis=axes[i]) + return mask + + incoming_mask = mask_creator([0], [0, 1, 2], current_param, neuron_mask) + outgoing_mask = mask_creator([1], [0, 1, 3], next_param, neuron_mask) + return incoming_mask, outgoing_mask + + +class NeuronRecyclerScheduled(NeuronRecycler): + """Fixed scheduled version of the NeuronRecycler.""" + + def __init__( + self, + *args, + score_type='redo', + recycle_rate=0.3, + **kwargs, + ): + super(NeuronRecyclerScheduled, self).__init__(*args, **kwargs) + self.score_type = score_type + self.recycle_rate = recycle_rate + + def _score2mask(self, activation, param, next_param, key): + is_cbp = self.score_type == 'cbp' + score = self.estimate_neuron_score(activation, is_cbp=is_cbp) + if self.score_type == 'redo': + pass + elif self.score_type == 'random': + new_key = random.fold_in(key, self._last_update_step) + score = random.shuffle(new_key, score) + elif self.score_type == 'redo_inverted': + score = -score + # Metric used in Continual Backprop pape. 
+ elif self.score_type == 'cbp': + next_axes = list(range(param.ndim)) + del next_axes[-2] + current_axes = list(range(param.ndim)) + del current_axes[-1] + if next_param.ndim == 2 and param.ndim == 4: + new_shape = activation.shape[1:] + (-1,) + next_param = jnp.reshape(next_param, new_shape) + score *= jnp.sum(jnp.abs(next_param), axis=next_axes) / jnp.sum( + jnp.abs(param), axis=current_axes + ) + multiplier = max(0, self._last_update_step / self.reset_end_step) + ones_fraction = float(jnp.cos(jnp.pi * 0.5 * multiplier)) + ones_fraction *= self.recycle_rate + return leastk_mask(score, ones_fraction) + + + +if __name__ == "__main__": + import gym + import numpy as np + import procgen + import torch + import torch.nn as nn + import torch.optim as optim + from torch.distributions.categorical import Categorical + + def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + # taken from https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py + class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + + def forward(self, x): + inputs = x + x = nn.functional.relu(x) + x = self.conv0(x) + x = nn.functional.relu(x) + x = self.conv1(x) + return x + inputs + + class ConvSequence(nn.Module): + def __init__(self, input_shape, out_channels): + super().__init__() + self._input_shape = input_shape + self._out_channels = out_channels + self.conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1) + self.res_block0 = ResidualBlock(self._out_channels) + self.res_block1 = ResidualBlock(self._out_channels) + + def forward(self, 
x): + x = self.conv(x) + x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1) + x = self.res_block0(x) + x = self.res_block1(x) + assert x.shape[1:] == self.get_output_shape() + return x + + def get_output_shape(self): + _c, h, w = self._input_shape + return (self._out_channels, (h + 1) // 2, (w + 1) // 2) + + class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + h, w, c = envs.single_observation_space.shape + shape = (c, h, w) + conv_seqs = [] + for out_channels in [16, 32, 32]: + conv_seq = ConvSequence(shape, out_channels) + shape = conv_seq.get_output_shape() + conv_seqs.append(conv_seq) + conv_seqs += [ + nn.Flatten(), + nn.ReLU(), + nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=256), + nn.ReLU(), + ] + self.network = nn.Sequential(*conv_seqs) + self.actor = layer_init(nn.Linear(256, envs.single_action_space.n), std=0.01) + self.critic = layer_init(nn.Linear(256, 1), std=1) + + def get_value(self, x): + return self.critic(self.network(x.permute((0, 3, 1, 2)) / 255.0)) # "bhwc" -> "bchw" + + def get_action_and_value(self, x, action=None, full_log_probs=False): + hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0) # "bhwc" -> "bchw" + value = self.critic(hidden) + logits = self.actor(hidden) + dist = Categorical(logits=logits) + if action is None: + action = dist.sample() + + action_log_probs = torch.squeeze(dist.log_prob(action)) + dist_entropy = dist.entropy() + + if full_log_probs: + log_probs = torch.log(dist.probs) + return action, action_log_probs, dist_entropy, value, log_probs + + return action, action_log_probs, dist_entropy, value + + def make_env(env_id, seed): + def thunk(): + env = gym.make(f"procgen-{env_id}-v0", rand_seed=seed, distribution_mode="easy") + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + return thunk + + envs = gym.vector.SyncVectorEnv( + [ + make_env("bigfish", 1 + i) + for i in range(1) + ] + ) + agent = Agent(envs).to("cuda") + optimizer = 
optim.Adam(agent.parameters(), lr=5e-4, eps=1e-5) + + online_params = [] + for name, module in agent.named_children(): + if not name.startswith('params'): + online_params.append(name) + weight_recycler = NeuronRecycler(online_params) + rng = jax.random.PRNGKey(1) + + for update_step in range(10000): + # Neuron/layer recycling starts if reset_mode is not None. + # Otherwise, we log dead neurons over training for standard agent. + is_intermediated = weight_recycler.is_intermediated_required(update_step) + + # get intermediate activation per layer to calculate neuron score. + def get_intermediates(online_params): + batch = self._sample_batch_for_statistics() + + def apply_data(x): + filter_rep = lambda l, _: l.name is not None and 'act' in l.name + return self.network_def.apply( + online_params, + x, + capture_intermediates=filter_rep, + mutable=['intermediates']) + + _, state = jax.vmap(apply_data)(batch) + return state['intermediates'] + + intermediates = get_intermediates(agent.parameters()) if is_intermediated else None + log_dict_neurons = (weight_recycler.maybe_log_deadneurons(update_step, intermediates)) + + # Neuron/layer recyling. 
+ rng, key = jax.random.split(rng) + online_params, opt_state = weight_recycler.maybe_update_weights( + update_step, intermediates, online_params, key, optimizer.state_dict() + ) + optimizer.load_state_dict(opt_state) + with torch.no_grad(): + for agent_param, new_param in zip(agent.parameters().values(), online_params.values()): + agent_param.data = new_param diff --git a/examples/experimental/rllib_cartpole.py b/examples/experimental/rllib_cartpole.py new file mode 100644 index 00000000..8877a279 --- /dev/null +++ b/examples/experimental/rllib_cartpole.py @@ -0,0 +1,40 @@ +import gymnasium as gym +from gymnasium.spaces import Box +from ray import tune +from ray.tune.registry import register_env +from syllabus.core import RaySyncWrapper, make_ray_curriculum +from syllabus.curricula import SimpleBoxCurriculum +from syllabus.task_space import TaskSpace + +from .task_wrappers import CartPoleTaskWrapper + +# Define a task space +if __name__ == "__main__": + task_space = TaskSpace(Box(-0.3, 0.3, shape=(2,))) + + def env_creator(config): + env = gym.make("CartPole-v1") + # Wrap the environment to change tasks on reset() + env = CartPoleTaskWrapper(env) + # Add environment sync wrapper + env = RaySyncWrapper( + env, task_space=task_space, update_on_step=False + ) + return env + + register_env("task_cartpole", env_creator) + + # Create the curriculum + curriculum = SimpleBoxCurriculum(task_space) + # Add the curriculum sync wrapper + curriculum = make_ray_curriculum(curriculum) + + config = { + "env": "task_cartpole", + "num_gpus": 1, + "num_workers": 8, + "framework": "torch", + } + + tuner = tune.Tuner("APEX", param_space=config) + results = tuner.fit() diff --git a/examples/experimental/sb3_procgen_plr.py b/examples/experimental/sb3_procgen_plr.py new file mode 100644 index 00000000..f9cde9f9 --- /dev/null +++ b/examples/experimental/sb3_procgen_plr.py @@ -0,0 +1,116 @@ +from typing import Callable + +import gym +import procgen # noqa: F401 +import wandb +from 
stable_baselines3 import PPO +from stable_baselines3.common.callbacks import BaseCallback, CallbackList +from stable_baselines3.common.vec_env import (DummyVecEnv, VecMonitor, + VecNormalize) +from syllabus.core import (MultiProcessingSyncWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import CentralizedPrioritizedLevelReplay +from syllabus.examples.task_wrappers import ProcgenTaskWrapper +from wandb.integration.sb3 import WandbCallback + + +def make_env(task_queue, update_queue, start_level=0, num_levels=1): + def thunk(): + env = gym.make("procgen-bigfish-v0", distribution_mode="easy", start_level=start_level, num_levels=num_levels) + env = ProcgenTaskWrapper(env) + env = MultiProcessingSyncWrapper( + env, + task_queue, + update_queue, + update_on_step=False, + task_space=env.task_space, + ) + return env + return thunk + + +def wrap_vecenv(vecenv): + vecenv.is_vector_env = True + vecenv = VecMonitor(venv=vecenv, filename=None) + vecenv = VecNormalize(venv=vecenv, norm_obs=False, norm_reward=True) + return vecenv + + +class CustomCallback(BaseCallback): + """ + A custom callback that derives from ``BaseCallback``. 
+ + :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages + """ + def __init__(self, curriculum, verbose=0): + super().__init__(verbose) + self.curriculum = curriculum + + def _on_step(self) -> bool: + tasks = self.training_env.venv.venv.venv.get_attr("task") + + update = { + "update_type": "on_demand", + "metrics": { + "value": self.locals["values"], + "next_value": self.locals["values"], + "rew": self.locals["rewards"], + "dones": self.locals["dones"], + "tasks": tasks, + }, + } + self.curriculum.update_curriculum(update) + return True + + +def linear_schedule(initial_value: float) -> Callable[[float], float]: + def func(progress_remaining: float) -> float: + return progress_remaining * initial_value + return func + + +run = wandb.init( + project="sb3", + entity="ryansullivan", + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + save_code=True, # optional +) + + +sample_env = gym.make("procgen-bigfish-v0") +sample_env = ProcgenTaskWrapper(sample_env) +curriculum = CentralizedPrioritizedLevelReplay(sample_env.task_space, num_processes=64, num_steps=256) +curriculum, task_queue, update_queue = make_multiprocessing_curriculum(curriculum) +venv = DummyVecEnv( + [ + make_env(task_queue, update_queue, num_levels=0) + for i in range(64) + ] +) +venv = wrap_vecenv(venv) + +model = PPO( + "CnnPolicy", + venv, + verbose=1, + n_steps=256, + learning_rate=linear_schedule(0.0005), + gamma=0.999, + gae_lambda=0.95, + n_epochs=3, + clip_range_vf=0.2, + ent_coef=0.01, + batch_size=256 * 64, + tensorboard_log="runs/testing" +) + +wandb_callback = WandbCallback( + model_save_path=f"models/{run.id}", + verbose=2, +) +plr_callback = CustomCallback(curriculum) +callback = CallbackList([wandb_callback, plr_callback]) +model.learn( + 25000000, + callback=callback, +) diff --git a/examples/experimental/torchbeast_nethack.py b/examples/experimental/torchbeast_nethack.py new file mode 100644 index 00000000..21767b1f --- /dev/null +++ 
b/examples/experimental/torchbeast_nethack.py @@ -0,0 +1,1280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is an example self-contained agent running NLE based on MonoBeast. + +import argparse +import logging +import math +import os +import pprint +import threading +import time +import timeit +import traceback + +import numpy as np +import wandb +from syllabus.core import (MultiProcessingSyncWrapper, + make_multiprocessing_curriculum) +from syllabus.curricula import LearningProgressCurriculum +from syllabus.examples import NethackTaskWrapper + +# Necessary for multithreading. +os.environ["OMP_NUM_THREADS"] = "1" + +try: + import torch + from torch import multiprocessing as mp + from torch import nn + from torch.nn import functional as F +except ImportError: + logging.exception( + "PyTorch not found. 
Please install the agent dependencies with " + '`pip install "nle[agent]"`' + ) + +import gym # noqa: E402 +import nle # noqa: F401, E402 +from nle import nethack # noqa: E402 +from nle.agent import vtrace # noqa: E402 + + +def parse_args(): + # yapf: disable + parser = argparse.ArgumentParser(description="PyTorch Scalable Agent") + + parser.add_argument("--env", type=str, default="NetHackScore-v0", + help="Gym environment.") + parser.add_argument("--mode", default="train", + choices=["train", "test", "test_render"], + help="Training or test mode.") + parser.add_argument("--profile", action="store_true", + help="Profile main process.") + parser.add_argument("--profile_worker", action="store_true", + help="Profile worker process.") + + # Training settings. + parser.add_argument("--disable_checkpoint", action="store_true", + help="Disable saving checkpoint.") + parser.add_argument("--savedir", default="~/torchbeast/", + help="Root dir where experiment data will be saved.") + parser.add_argument("--num_actors", default=8, type=int, metavar="N", + help="Number of actors (default: 4).") + parser.add_argument("--total_steps", default=100000, type=int, metavar="T", + help="Total environment steps to train for.") + parser.add_argument("--batch_size", default=16, type=int, metavar="B", + help="Learner batch size.") + parser.add_argument("--unroll_length", default=80, type=int, metavar="T", + help="The unroll length (time dimension).") + parser.add_argument("--num_buffers", default=None, type=int, + metavar="N", help="Number of shared-memory buffers.") + parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int, + metavar="N", help="Number learner threads.") + parser.add_argument("--disable_cuda", action="store_true", + help="Disable CUDA.") + parser.add_argument("--use_lstm", action="store_true", + help="Use LSTM in agent model.") + parser.add_argument("--save_ttyrec_every", default=1000, type=int, + metavar="N", help="Save ttyrec every N episodes.") 
+ parser.add_argument("--save_video", action="store_true", + help="Save and log video during training.") + + # Curriculum Settings + parser.add_argument("--curriculum", action="store_true", + help="Use Syllabus curricula.") + + # Testing settings + parser.add_argument("--reward_frames", action="store_true", + help="Only print reward frames and show inventory.") + parser.add_argument("--item_frames", action="store_true", + help="Only print frames where the agent picks up items and show inventory.") + parser.add_argument("--message", action="store_true", + help="Set to true without the above options to display only the messages.") + parser.add_argument("--custompath", type=str, + help="Set a custom path to draw a tar test file from.") + + # Weights and Biases settings + parser.add_argument("--exp_name", type=str, default="nle_baseline", + help="Set name for wandb experiment.") + parser.add_argument("--wandb_id", default=1, type=int, + help="Set id for wandb experiment.") + + # Loss settings. + parser.add_argument("--entropy_cost", default=0.0006, + type=float, help="Entropy cost/multiplier.") + parser.add_argument("--baseline_cost", default=0.5, + type=float, help="Baseline cost/multiplier.") + parser.add_argument("--discounting", default=0.99, + type=float, help="Discounting factor.") + parser.add_argument("--reward_clipping", default="abs_one", + choices=["abs_one", "none"], + help="Reward clipping.") + + # Optimizer settings. 
+ parser.add_argument("--learning_rate", default=0.00048, + type=float, metavar="LR", help="Learning rate.") + parser.add_argument("--alpha", default=0.99, type=float, + help="RMSProp smoothing constant.") + parser.add_argument("--momentum", default=0, type=float, + help="RMSProp momentum.") + parser.add_argument("--epsilon", default=0.01, type=float, + help="RMSProp epsilon.") + parser.add_argument("--grad_norm_clipping", default=40.0, type=float, + help="Global gradient norm clip.") + # yapf: enable + args = parser.parse_args() + args.exp_name = f"nethack__{args.exp_name}__{int(time.time())}" + return args + + +logging.basicConfig( + format=("[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"), + level=logging.INFO, + #filename="monobeast.log" +) + + +def nested_map(f, n): + if isinstance(n, tuple) or isinstance(n, list): + return n.__class__(nested_map(f, sn) for sn in n) + elif isinstance(n, dict): + return {k: nested_map(f, v) for k, v in n.items()} + else: + return f(n) + + +def compute_baseline_loss(advantages): + return 0.5 * torch.sum(advantages**2) + + +def compute_entropy_loss(logits): + """Return the entropy loss, i.e., the negative entropy of the policy.""" + policy = F.softmax(logits, dim=-1) + log_policy = F.log_softmax(logits, dim=-1) + return torch.sum(policy * log_policy) + + +def compute_policy_gradient_loss(logits, actions, advantages): + cross_entropy = F.nll_loss( + F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), + target=torch.flatten(actions, 0, 1), + reduction="none", + ) + cross_entropy = cross_entropy.view_as(advantages) + return torch.sum(cross_entropy * advantages.detach()) + + +def create_env( + name, + *args, + chars=False, + task_queue: mp.SimpleQueue = None, + complete_queue: mp.SimpleQueue = None, + step_queue: mp.SimpleQueue = None, + observation_keys=None, + **kwargs +): + if flags.curriculum: + observation_keys = ("glyphs", "blstats", "message") + #observation_keys = ("glyphs", "blstats", 
"inv_glyphs", "inv_strs", "inv_letters", "inv_oclasses", + # "message", "tty_chars", "tty_colors", "tty_cursor", "internal") + if chars: + observation_keys += ("chars",) + + env = gym.make(name, *args, observation_keys=observation_keys, penalty_step=0.0, **kwargs) + env = NethackTaskWrapper(env) + env = MultiProcessingSyncWrapper(env, + task_queue, + complete_queue, + step_queue=step_queue, + update_on_step=True, + default_task=0, + task_space=env.task_space) + else: + env = gym.make(name, *args, observation_keys=observation_keys, **kwargs) + env = ResettingEnvironment(env) + return env + + +def act( + flags, + actor_index: int, + free_queue: mp.SimpleQueue, + full_queue: mp.SimpleQueue, + model: torch.nn.Module, + buffers, + initial_agent_state_buffers, + task_queue: mp.SimpleQueue = None, + complete_queue: mp.SimpleQueue = None, + step_queue: mp.SimpleQueue = None, + + video_dir=None +): + try: + if flags.profile_worker and actor_index == 0: + import cProfile + from pstats import Stats + pr = cProfile.Profile() + pr.enable() + + logging.info("Actor %i started.", actor_index) + + gym_env = create_env( + flags.env, + savedir=flags.rundir, + save_ttyrec_every=flags.save_ttyrec_every, + observation_keys=("glyphs", "blstats"), + task_queue=task_queue, + complete_queue=complete_queue, + step_queue=step_queue, + ) + + env = gym_env + env_output = env.initial() + agent_state = model.initial_state(batch_size=1) + agent_output, _ = model(env_output, agent_state) + env_output = env.step(agent_output["action"]) + + while True: + index = free_queue.get() + if index is None: + break + + # Write old rollout end. + for key in env_output: + buffers[key][index][0, ...] = env_output[key] + for key in agent_output: + buffers[key][index][0, ...] = agent_output[key] + for i, tensor in enumerate(agent_state): + initial_agent_state_buffers[index][i][...] = tensor + + # Do new rollout. 
+ for t in range(flags.unroll_length): + with torch.no_grad(): + agent_output, agent_state = model(env_output, agent_state) + + env_output = env.step(agent_output["action"]) + if flags.curriculum and "info" in env_output: + info = env_output["info"] + + for key in env_output: + buffers[key][index][t + 1, ...] = env_output[key] + for key in agent_output: + buffers[key][index][t + 1, ...] = agent_output[key] + + full_queue.put(index) + + if flags.profile_worker and actor_index == 0: + pr.disable() + stats = Stats(pr) + stats.sort_stats('cumtime').print_stats(200) + + except KeyboardInterrupt: + pass # Return silently. + except Exception: + logging.error("Exception in worker process %i", actor_index) + traceback.print_exc() + print() + raise + + +def get_batch( + flags, + free_queue: mp.SimpleQueue, + full_queue: mp.SimpleQueue, + buffers, + initial_agent_state_buffers, + lock=threading.Lock(), +): + with lock: + indices = [full_queue.get() for _ in range(flags.batch_size)] + batch = { + key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers + } + + initial_agent_state = ( + torch.cat(ts, dim=1) + for ts in zip(*[initial_agent_state_buffers[m] for m in indices]) + ) + for m in indices: + free_queue.put(m) + batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()} + initial_agent_state = tuple( + t.to(device=flags.device, non_blocking=True) for t in initial_agent_state + ) + return batch, initial_agent_state + + +def learn( + flags, + actor_model, + model, + batch, + initial_agent_state, + optimizer, + scheduler, + lock=threading.Lock(), # noqa: B008 +): + """Performs a learning (optimization) step.""" + with lock: + learner_outputs, unused_state = model(batch, initial_agent_state) + + # Take final value function slice for bootstrapping. + bootstrap_value = learner_outputs["baseline"][-1] + + # Move from obs[t] -> action[t] to action[t] -> obs[t]. 
+ batch = {key: tensor[1:] for key, tensor in batch.items()} + learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()} + + rewards = batch["reward"] + if flags.reward_clipping == "abs_one": + clipped_rewards = torch.clamp(rewards, -1, 1) + elif flags.reward_clipping == "none": + clipped_rewards = rewards + + discounts = (~batch["done"]).float() * flags.discounting + + vtrace_returns = vtrace.from_logits( + behavior_policy_logits=batch["policy_logits"], + target_policy_logits=learner_outputs["policy_logits"], + actions=batch["action"], + discounts=discounts, + rewards=clipped_rewards, + values=learner_outputs["baseline"], + bootstrap_value=bootstrap_value, + ) + + pg_loss = compute_policy_gradient_loss( + learner_outputs["policy_logits"], + batch["action"], + vtrace_returns.pg_advantages, + ) + baseline_loss = flags.baseline_cost * compute_baseline_loss( + vtrace_returns.vs - learner_outputs["baseline"] + ) + entropy_loss = flags.entropy_cost * compute_entropy_loss( + learner_outputs["policy_logits"] + ) + + total_loss = pg_loss + baseline_loss + entropy_loss + + episode_returns = batch["episode_return"][batch["done"]] + episode_steps = batch["episode_step"][batch["done"]] + + stats = { + "episode_returns": tuple(episode_returns.cpu().numpy()), + "episode_lengths": tuple(episode_steps.cpu().numpy()), + "mean_episode_length": torch.mean(episode_steps).item(), + "mean_episode_return": torch.mean(episode_returns).item(), + "total_loss": total_loss.item(), + "pg_loss": pg_loss.item(), + "baseline_loss": baseline_loss.item(), + "entropy_loss": entropy_loss.item(), + } + + optimizer.zero_grad() + total_loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping) + optimizer.step() + scheduler.step() + + actor_model.load_state_dict(model.state_dict()) + return stats + + +def create_buffers(flags, observation_space, num_actions, num_overlapping_steps=1): + size = (flags.unroll_length + num_overlapping_steps,) + + # Get 
specimens to infer shapes and dtypes. + samples = {k: torch.from_numpy(v) for k, v in observation_space.sample().items()} + + specs = { + key: dict(size=size + sample.shape, dtype=sample.dtype) + for key, sample in samples.items() + } + specs.update( + reward=dict(size=size, dtype=torch.float32), + done=dict(size=size, dtype=torch.bool), + episode_return=dict(size=size, dtype=torch.float32), + episode_step=dict(size=size, dtype=torch.float32), + policy_logits=dict(size=size + (num_actions,), dtype=torch.float32), + baseline=dict(size=size, dtype=torch.float32), + last_action=dict(size=size, dtype=torch.int64), + action=dict(size=size, dtype=torch.int64), + ) + buffers = {key: [] for key in specs} + for _ in range(flags.num_buffers): + for key in buffers: + buffers[key].append(torch.empty(**specs[key]).share_memory_()) + return buffers + + +def _format_observations(observation): + observations = {} + for key in list(observation.keys()): + entry = observation[key] + if isinstance(entry, np.ndarray): + entry = torch.from_numpy(entry) + entry = entry.view((1, 1) + entry.shape) # (...) -> (T,B,...). + observations[key] = entry + return observations + + +class ResettingEnvironment: + """Turns a Gym environment into something that can be step()ed indefinitely.""" + + def __init__(self, gym_env): + self.gym_env = gym_env + self.episode_return = None + self._copy_gym_properties() + + def _copy_gym_properties(self): + self.action_space = self.gym_env.action_space + self.observation_space = self.gym_env.observation_space + self.reward_range = self.gym_env.reward_range + self.metadata = self.gym_env.metadata + if flags.curriculum: + self.task_space = self.gym_env.task_space + + def initial(self): + initial_reward = torch.zeros(1, 1) + # This supports only single-tensor actions ATM. 
+ initial_last_action = torch.zeros(1, 1, dtype=torch.int64) + self.episode_return = torch.zeros(1, 1) + self.episode_step = torch.zeros(1, 1, dtype=torch.float32) + initial_done = torch.ones(1, 1, dtype=torch.uint8) + + result = _format_observations(self.gym_env.reset()) + result.update( + reward=initial_reward, + done=initial_done, + episode_return=self.episode_return, + episode_step=self.episode_step, + last_action=initial_last_action, + ) + return result + + def step(self, action): + observation, reward, done, info = self.gym_env.step(action.item()) + self.episode_step += 1 + self.episode_return += reward + episode_step = self.episode_step + episode_return = self.episode_return + + if done: + observation = self.gym_env.reset() + self.episode_return = torch.zeros(1, 1) + self.episode_step = torch.zeros(1, 1, dtype=torch.float32) + + result = _format_observations(observation) + + reward = torch.tensor(reward).view(1, 1) + done = torch.tensor(done).view(1, 1) + if flags.curriculum: + result.update( + reward=reward, + done=done, + episode_return=episode_return, + episode_step=episode_step, + last_action=action, + ) + else: + result.update( + reward=reward, + done=done, + episode_return=episode_return, + episode_step=episode_step, + last_action=action, + ) + + return result + + def close(self): + self.gym_env.close() + + +def parse_logpaths(flags): + flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) + + if flags.exp_name: + rundir = os.path.join(flags.savedir, f"torchbeast-{flags.exp_name}-{flags.wandb_id}") + else: + rundir = os.path.join(flags.savedir, "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")) + + checkpointpath = os.path.join(rundir, "model.tar") + + # TODO: Check if run name + id exists, and resume + resume_checkpoint = None + if os.path.exists(rundir): + resume_checkpoint = torch.load(checkpointpath) + # TODO: Make sure this doesn't overwrite anything important, + # or that we want to be able to change + # flags = 
Namespace(**resume_checkpoint["flags"]) + else: + os.makedirs(rundir) + + logfile = open(os.path.join(rundir, "logs.tsv"), "a", buffering=1) + logging.info("Logging results to %s", rundir) + + symlink = os.path.join(flags.savedir, "latest") + try: + if os.path.islink(symlink): + os.remove(symlink) + if not os.path.exists(symlink): + os.symlink(rundir, symlink) + logging.info("Symlinked log directory: %s", symlink) + except OSError: + raise + + flags.rundir = rundir + return flags, rundir, checkpointpath, resume_checkpoint, logfile + + +def train(flags, wandb_run=None): # pylint: disable=too-many-branches, too-many-statements + # Set all filepaths using provided arguments + flags, rundir, checkpointpath, resume_checkpoint, logfile = parse_logpaths(flags) + + if flags.num_buffers is None: # Set sensible default for num_buffers. + flags.num_buffers = max(2 * flags.num_actors, flags.batch_size) + if flags.num_actors >= flags.num_buffers: + raise ValueError("num_buffers should be larger than num_actors") + if flags.num_buffers < flags.batch_size: + raise ValueError("num_buffers should be larger than batch_size") + + T = flags.unroll_length + B = flags.batch_size + + flags.device = None + if not flags.disable_cuda and torch.cuda.is_available(): + logging.info("Using CUDA.") + flags.device = torch.device("cuda") + else: + logging.info("Not using CUDA.") + flags.device = torch.device("cpu") + + sample_env = create_env(flags.env, observation_keys=("glyphs", "blstats")) + observation_space = sample_env.observation_space + action_space = sample_env.action_space + if flags.curriculum: + task_space = sample_env.task_space + + model = Net(observation_space, action_space.n, flags.use_lstm, goal=flags.curriculum) + buffers = create_buffers(flags, observation_space, model.num_actions) + + model.share_memory() + + def lr_lambda(epoch): + return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps + + # Resume training + # if resume_checkpoint: + # # TODO: Fix resume code + 
# #curriculum = LearningProgressCurriculum(tasks=resume_checkpoint["curriculum"]) + # if flags.curriculum: + # curriculum = LearningProgressCurriculum(task_space=task_space) + # model.load_state_dict(resume_checkpoint["model"]) + # learner_model = Net(observation_space, action_space.n, flags.use_lstm, goal=flags.curriculum).to( + # device=flags.device + # ) + # learner_model.load_state_dict(model.state_dict()) + # optimizer = torch.optim.RMSprop( + # learner_model.parameters(), + # lr=flags.learning_rate, + # momentum=flags.momentum, + # eps=flags.epsilon, + # alpha=flags.alpha, + # ) + # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + # scheduler.load_state_dict(resume_checkpoint["scheduler"]) + # optimizer.load_state_dict(resume_checkpoint['optimizer']) + # step = resume_checkpoint["step"] + # + # else: + task_queue, complete_queue, step_queue = None, None, None + if flags.curriculum: + name_list = None + if isinstance(sample_env.task_space, gym.spaces.Discrete): + task_list = sample_env.gym_env.task_list + name_list = [task_list[idx].__name__ for idx in range(len(task_list))] + curriculum, task_queue, complete_queue, step_queue = make_multiprocessing_curriculum(LearningProgressCurriculum, + task_space, + random_start_tasks=0, + task_names=name_list) + + learner_model = Net(observation_space, action_space.n, flags.use_lstm, goal=flags.curriculum).to(device=flags.device) + learner_model.load_state_dict(model.state_dict()) + + optimizer = torch.optim.RMSprop( + learner_model.parameters(), + lr=flags.learning_rate, + momentum=flags.momentum, + eps=flags.epsilon, + alpha=flags.alpha, + ) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + step = 0 + + # Add initial RNN state. + initial_agent_state_buffers = [] + for _ in range(flags.num_buffers): + state = model.initial_state(batch_size=1) + for t in state: + t.share_memory_() + initial_agent_state_buffers.append(state) + + del sample_env # End this before forking. 
+ + actor_processes = [] + ctx = mp.get_context("fork") + free_queue = ctx.SimpleQueue() + full_queue = ctx.SimpleQueue() + + for i in range(flags.num_actors): + actor = ctx.Process( + target=act, + args=( + flags, i, + free_queue, full_queue, + model, buffers, + initial_agent_state_buffers, + task_queue, complete_queue, step_queue, + rundir, + ), + name="Actor-%i" % i, + ) + actor.start() + actor_processes.append(actor) + + if flags.exp_name: + wandb.config = { + "learning_rate": flags.learning_rate, + "epsilon": flags.epsilon, + "alpha": flags.alpha, + "momentum": flags.momentum, + } + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + stat_keys = ["total_loss", "mean_episode_return", "pg_loss", "baseline_loss", "entropy_loss"] + logfile.write("# Step\t%s\n" % "\t".join(stat_keys)) + + all_stats = [] + stats = {} + + def batch_and_learn(i, lock=threading.Lock()): + """Thread target for the learning process.""" + nonlocal step, stats, all_stats + while step < flags.total_steps: + batch, agent_state = get_batch(flags, free_queue, full_queue, buffers, initial_agent_state_buffers) + stats = learn(flags, model, learner_model, batch, agent_state, optimizer, scheduler) + + all_stats.append(stats) + with lock: + logfile.write("%i\t" % step) + logfile.write("\t".join(str(stats[k]) for k in stat_keys)) + logfile.write("\n") + step += T * B + + for m in range(flags.num_buffers): + free_queue.put(m) + + threads = [] + for i in range(flags.num_learner_threads): + thread = threading.Thread( + target=batch_and_learn, + name="batch-and-learn-%d" % i, + args=(i,), + daemon=True, # To support KeyboardInterrupt below. 
+ ) + thread.start() + threads.append(thread) + + def checkpoint(): + if flags.disable_checkpoint: + return + logging.info("Saving checkpoint to %s", checkpointpath) + torch.save( + { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "flags": vars(flags), + "step": step + }, + checkpointpath, + ) + + timer = timeit.default_timer + recent_sps = [1.0] * 5 + try: + if flags.curriculum: + curriculum.log_metrics(step) + last_checkpoint_time = timer() + while step < flags.total_steps: + start_step = step + start_time = timer() + time.sleep(10) + + if timer() - last_checkpoint_time > 10 * 60: # Save every 10 min. + checkpoint() + last_checkpoint_time = timer() + #wandb.gym.monitor() + + # # Combine stats + all_stats_dict = {} + if len(all_stats) > 0: + # Iterate through keys and combine based on data type + for key, value in all_stats[0].items(): + all_stats_dict[key] = 0 + stat_list = [] + if isinstance(value, (int, float, np.uint8, np.float32)): + # Average values + for stat_dict in all_stats: + stat_value = stat_dict.get(key) + if stat_value is not None and not np.isnan(stat_value) and not math.isnan(stat_value): + stat_list.append(stat_value) + all_stats_dict[key] = (sum(stat_list) / len(stat_list)) if len(stat_list) > 0 else 0 + elif isinstance(value, (list, tuple)): + # Combine lists + for stat_dict in all_stats: + stat_value = stat_dict.get(key) + if stat_value is not None and stat_value != () and stat_value != []: + stat_list += stat_value + all_stats_dict[key] = stat_list + all_stats = [] + + # Remove clutter + if "episode_returns" in all_stats_dict: + del all_stats_dict["episode_returns"] + if "episode_lengths" in all_stats_dict: + del all_stats_dict["episode_lengths"] + + # Log run data to weights and biases + if flags.exp_name: + wandb_stats = all_stats_dict + for key, value in wandb_stats.items(): + if isinstance(value, (int, float, np.uint8, np.float32)): + wandb_stats[key] = 0.0 if np.isnan(value) 
else value + if not flags.curriculum: + wandb_stats["mean_score_return"] = wandb_stats["mean_episode_return"] + wandb_stats["learning_rate"] = scheduler.get_last_lr()[0] + # task_table.add_data(step, str(curriculum.export_task_names())) + wandb.log(wandb_stats, step=step) + + sps = (step - start_step) / (timer() - start_time) + + if stats.get("episode_returns", None): + mean_return = ( + "Return per episode: %.1f. " % stats["mean_episode_return"] + ) + else: + mean_return = "" + total_loss = stats.get("total_loss", float("inf")) + log_str = "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s" + log_args = [step, sps, total_loss, mean_return, pprint.pformat(all_stats_dict)] + # log_str = "Steps %i @ %.1f SPS. Loss %f. %s" + # log_args = [step, sps, total_loss, mean_return] + # if flags.curriculum: + # # TODO: Fix this + # task_names = curriculum.export_task_names()[:] + # log_str += "\nTasks: %s\nL_prog: %s\nP_fast: %s" + # log_args.append(task_names) + # # Get learning progress metric + # lps = curriculum.metric_for_tasks(task_names, metric="lp") + # lps = [f"{lp:.4f}" for lp in lps] + # log_args.append(lps) + # # Get recent estimate of success rates + # pfasts = curriculum.metric_for_tasks(task_names, metric="p_fast") + # pfasts = [f"{pfast:.4f}" for pfast in pfasts] + # log_args.append(pfasts) + logging.info(log_str, *log_args) + if flags.curriculum: + curriculum.log_metrics(step) + + # Stop training if sps remains 0 for too long + if total_loss != float("inf"): + for i in reversed(range(1, len(recent_sps))): + recent_sps[i] = recent_sps[i-1] + recent_sps[0] = sps + if sum(recent_sps) == 0.0: + return + + except KeyboardInterrupt: + logging.warning("Quitting.") + return # Try joining actors then quit. 
+ else: + for thread in threads: + thread.join() + logging.info("Learning finished after %d steps.", step) + finally: + for _ in range(flags.num_actors): + free_queue.put(None) + for actor in actor_processes: + actor.join(timeout=1) + + checkpoint() + logfile.close() + + +def test(flags, num_episodes=1): + flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) + print("savedir:" + str(flags.savedir)) + checkpointpath = flags.custompath if flags.custompath else os.path.join(flags.savedir, "latest", "model.tar") + + #gym_env = create_env(flags.env, save_ttyrecs=flags.save_ttyrecs) + observation_keys = ("glyphs", "blstats") + observation_keys = ("glyphs", "blstats", "inv_glyphs", "inv_strs", "inv_letters", "inv_oclasses", + "message", "tty_chars", "tty_colors", "tty_cursor") + gym_env = create_env(flags.env, observation_keys=observation_keys) + env = ResettingEnvironment(gym_env) + model = Net(gym_env.observation_space, gym_env.action_space.n, flags.use_lstm, goal=flags.curriculum) + model.eval() + checkpoint = torch.load(checkpointpath, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + + observation = env.initial() + returns = [] + + agent_state = model.initial_state(batch_size=1) + last_inv = [] + + while len(returns) < num_episodes: + if flags.mode == "test_render": + if flags.item_frames or flags.reward_frames: + #The following lines parse the inventory and store it as an array in all_items_txt + all_items_text = [] + all_items_encoded = observation["inv_strs"][0][0].numpy() + for x in range(55): + if(all_items_encoded[x][0] != 0): + all_items_text.append(''.join([str((chr(elem))) for elem in all_items_encoded[x][all_items_encoded[x] != 0]])) + if observation["reward"].item() > 0 and flags.reward_frames or all_items_text != last_inv and flags.item_frames: + env.gym_env.render() + print("Reward is %d" % observation["reward"].item()) + print(all_items_text) + last_inv = all_items_text + elif flags.message: + #Parse and print the 
message + message_encoded = observation["message"][0][0] + message = "" + for x in range(256): + message += chr(message_encoded[x].item()) + print(message) + else: + env.gym_env.render() + policy_outputs, agent_state = model(observation, agent_state) + observation = env.step(policy_outputs["action"]) + if observation["done"].item(): + last_inv = [] + returns.append(observation["episode_return"].item()) + logging.info( + "Episode ended after %d steps. Return: %.1f", + observation["episode_step"].item(), + observation["episode_return"].item(), + ) + env.close() + logging.info( + "Average returns over %i steps: %.1f", num_episodes, sum(returns) / len(returns) + ) + + +class RandomNet(nn.Module): + def __init__(self, observation_shape, num_actions, use_lstm): + super(RandomNet, self).__init__() + del observation_shape, use_lstm + self.num_actions = num_actions + self.theta = torch.nn.Parameter(torch.zeros(self.num_actions)) + + def forward(self, inputs, core_state): + # print(inputs) + T, B, *_ = inputs["observation"].shape + zeros = self.theta * 0 + # set logits to 0 + policy_logits = zeros[None, :].expand(T * B, -1) + # set baseline to 0 + baseline = policy_logits.sum(dim=1).view(-1, B) + + # sample random action + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1).view( + T, B + ) + policy_logits = policy_logits.view(T, B, self.num_actions) + return ( + dict(policy_logits=policy_logits, baseline=baseline, action=action), + core_state, + ) + + def initial_state(self, batch_size): + return () + + +def _step_to_range(delta, num_steps): + """ + Range of `num_steps` values separated by distance `delta` centered around zero. + Given an image with x,y in [-1, 1], return represents the crop range and sampling points on one axis. 
+ """ + delta_range = delta * torch.arange(-(num_steps // 2), (num_steps + 1) // 2) + return delta_range + + +class Crop(nn.Module): + """Helper class for NetHackNet below.""" + + def __init__(self, height, width, height_target, width_target): + super(Crop, self).__init__() + self.width = width # 79 + self.height = height # 21 + self.width_target = width_target # 9 + self.height_target = height_target # 9 + + # Create row-wise sampling grid and repeat across height_target rows + width_grid = _step_to_range(2 / (self.width - 1), self.width_target)[None, :].expand(self.height_target, -1) + # Create column-wise sampling grid and repeat across width_target columns + height_grid = _step_to_range(2 / (self.height - 1), height_target)[:, None].expand(-1, self.width_target) + + # "clone" necessary, https://github.com/pytorch/pytorch/issues/34880 + self.register_buffer("width_grid", width_grid.clone()) + self.register_buffer("height_grid", height_grid.clone()) + + def forward(self, inputs, coordinates): + """Calculates centered crop around given x,y coordinates. + Args: + inputs [B x H x W] + coordinates [B x 2] x,y coordinates + Returns: + [B x H' x W'] inputs cropped and centered around x,y coordinates. 
+ """ + + assert inputs.shape[1] == self.height + assert inputs.shape[2] == self.width + + # Add another axis at 1 + inputs = inputs[:, None, :, :].float() + + # Extract x,y coordinates for each sample + x = coordinates[:, 0] + y = coordinates[:, 1] + + # Recenter x and y values around 0 and normalize to range [-1, 1] + x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2) + y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2) + + # Shifts sampling grid to be centered around x, y + grid = torch.stack( + [ + self.width_grid[None, :, :] + x_shift[:, None, None], + self.height_grid[None, :, :] + y_shift[:, None, None], + ], + dim=3, + ) + + # TODO: only cast to int if original tensor was int + return ( + torch.round(F.grid_sample(inputs, grid, align_corners=True)) + .squeeze(1) + .long() + ) + + +class NetHackNet(nn.Module): + def __init__( + self, + observation_shape, + num_actions, + use_lstm, + embedding_dim=32, + crop_dim=9, + num_layers=5, + goal=False + ): + super(NetHackNet, self).__init__() + self.goal = goal + + self.glyph_shape = observation_shape["glyphs"].shape # (21, 79) + self.blstats_size = observation_shape["blstats"].shape[0] # 27 + if self.goal: + self.goal_size = observation_shape["goal"].shape[0] # 785? 
+ + self.num_actions = num_actions # 23 + self.use_lstm = use_lstm + + self.H = self.glyph_shape[0] # 21 + self.W = self.glyph_shape[1] # 79 + + self.k_dim = embedding_dim # 32 + self.h_dim = 512 + + self.crop_dim = crop_dim # 9 + + self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim) + + self.embed = nn.Embedding(nethack.MAX_GLYPH, self.k_dim) # 5976, 32 + + K = embedding_dim # number of input filters + F = 3 # filter dimensions + S = 1 # stride + P = 1 # padding + M = 16 # number of intermediate filters + Y = 8 # number of output filters + L = num_layers # number of convnet layers # 5 + + in_channels = [K] + [M] * (L - 1) # [32, 16, 16, 16, 16] + out_channels = [M] * (L - 1) + [Y] # [16, 16, 16, 16, 8] + + def interleave(xs, ys): + return [val for pair in zip(xs, ys) for val in pair] + + conv_extract = [ + nn.Conv2d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=(F, F), + stride=S, + padding=P, + ) + for i in range(L) + ] + + # Create a sequential net of alternating Conv2d and ELU layers + self.extract_representation = nn.Sequential( + *interleave(conv_extract, [nn.ELU()] * len(conv_extract)) + ) + + # CNN crop model. 
+ conv_extract_crop = [ + nn.Conv2d( + in_channels=in_channels[i], + out_channels=out_channels[i], + kernel_size=(F, F), + stride=S, + padding=P, + ) + for i in range(L) + ] + + # Create a sequential net of alternating Conv2d and ELU layers + self.extract_crop_representation = nn.Sequential( + *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract)) + ) + + # Blstats output dim + out_dim = self.k_dim # 32 + # Map glyphs output dim + out_dim += self.H * self.W * Y # 32 + 21 * 79 * 8 = 13304 + # Cropped map glyphs output dim + out_dim += self.crop_dim**2 * Y # 13304 + 9 * 9 * 8 = 13952 + if self.goal: + # Goal output dim + out_dim += int(self.goal_size / 2) # 13952 + 392 = 14344 + + # Blstats encoding 27 to 32 + self.embed_blstats = nn.Sequential( + nn.Linear(self.blstats_size, self.k_dim), + nn.ReLU(), + nn.Linear(self.k_dim, self.k_dim), + nn.ReLU(), + ) + + # Goal encoding 27 to 32 + if self.goal: + goal_dim = int(self.goal_size / 2) + self.embed_goals = nn.Sequential( + nn.Linear(self.goal_size, goal_dim), + nn.ReLU(), + nn.Linear(goal_dim, goal_dim), + nn.ReLU(), + ) + + # Fully connected from embedding to policy 13304 to 512 + self.fc = nn.Sequential( + nn.Linear(out_dim, self.h_dim), + nn.ReLU(), + nn.Linear(self.h_dim, self.h_dim), + nn.ReLU(), + ) + + if self.use_lstm: + # LSTM 512 to 512 + self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1) + + # Policy 512 to 23 + self.policy = nn.Linear(self.h_dim, self.num_actions) + # Baseline 512 to 1 + self.baseline = nn.Linear(self.h_dim, 1) + + def initial_state(self, batch_size=1): + if self.use_lstm: + return tuple( + torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) + for _ in range(2) + ) + return tuple() + + def _select(self, embed, x): + # Work around slow backward pass of nn.Embedding, see + # https://github.com/pytorch/pytorch/issues/24912 + out = embed.weight.index_select(0, x.reshape(-1)) + return out.reshape(x.shape + (-1,)) + + def forward(self, env_outputs, core_state): + 
# Set up glyphs + # -- [T x B x H x W] + glyphs = env_outputs["glyphs"] # [1, 1, 21, 79] + T, B, *_ = glyphs.shape + # -- [B' x H x W] + glyphs = torch.flatten(glyphs, 0, 1) # Merge time and batch. # [1, 21, 79] + # -- [B x H x W] + glyphs = glyphs.long() + + # Bottom line stats and extract coordinates + # -- [T x B x F] + blstats = env_outputs["blstats"] # [1, 1, 27] + # -- [B' x F] + blstats = blstats.view(T * B, -1).float() # [1, 27] + # -- [B x 2] x,y coordinates + coordinates = blstats[:, :2] # [1, 2] + # TODO ??? + # coordinates[:, 0].add_(-1) + # -- [B x F] + # TODO: Remove? + blstats = blstats.view(T * B, -1).float() # [1, 27] + # -- [B x K] + blstats_emb = self.embed_blstats(blstats) # [1, 32] + assert blstats_emb.shape[0] == T * B + reps = [blstats_emb] + + # Cropped map glyphs + # -- [B x H' x W'] + # Crop glyphs observation around player x,y coordinates + crop = self.crop(glyphs, coordinates) # [1, 9, 9] + # -- [B x H' x W' x K] + # Embed glyphs + crop_emb = self._select(self.embed, crop) # [1, 9, 9, 32] + # CNN crop model. + # -- [B x K x W' x H'] + crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? # [1, 32, 9, 9] + # -- [B x W' x H' x K] + # Convolutional pass to get representation of cropped region # [1, 8, 9, 9] + crop_rep = self.extract_crop_representation(crop_emb) + # -- [B x K'] + crop_rep = crop_rep.view(T * B, -1) # [1, 648] + assert crop_rep.shape[0] == T * B + reps.append(crop_rep) + + # -- [B x H x W x K] + # Full map glyphs + glyphs_emb = self._select(self.embed, glyphs) # [1, 21, 79, 32] + # -- [B x K x W x H] + glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? 
# [1, 32, 79, 21] + # -- [B x W x H x K] + glyphs_rep = self.extract_representation(glyphs_emb) # [1, 8, 79, 21] + # -- [B x K'] + glyphs_rep = glyphs_rep.view(T * B, -1) # [1, 13272] + assert glyphs_rep.shape[0] == T * B + # -- [B x K''] + reps.append(glyphs_rep) + + # TODO: Goals + if self.goal: + # -- [T x B x F] + goals = env_outputs["goal"] # [1, 1, 785] + # -- [B' x F] + goals = goals.view(T * B, -1).float() # [1, 785] + # -- [B x F] + # TODO: Remove? + goals = goals.view(T * B, -1).float() # [1, 785] + # -- [B x K] + goals_emb = self.embed_goals(goals) # [1, 32] + assert goals_emb.shape[0] == T * B + reps.append(goals_emb) + + # [32, 648, 13272] + st = torch.cat(reps, dim=1) + + # -- [B x K] + st = self.fc(st) # [1, 512] + + if self.use_lstm: + core_input = st.view(T, B, -1) + core_output_list = [] + notdone = (~env_outputs["done"]).float() + for input, nd in zip(core_input.unbind(), notdone.unbind()): + # Reset core state to zero whenever an episode ended. + # Make `done` broadcastable with (num_layers, B, hidden_size) + # states: + nd = nd.view(1, -1, 1) + core_state = tuple(nd * s for s in core_state) + output, core_state = self.core(input.unsqueeze(0), core_state) + core_output_list.append(output) + core_output = torch.flatten(torch.cat(core_output_list), 0, 1) + else: + core_output = st + + # -- [B x A] + policy_logits = self.policy(core_output) + # -- [B x A] + baseline = self.baseline(core_output) + + if self.training: + action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) + else: + # Don't sample when testing. 
+ action = torch.argmax(policy_logits, dim=1) + + policy_logits = policy_logits.view(T, B, self.num_actions) + baseline = baseline.view(T, B) + action = action.view(T, B) + + return ( + dict(policy_logits=policy_logits, baseline=baseline, action=action), + core_state, + ) + + +Net = NetHackNet + + +def main(flags, wandb=None): + if flags.mode == "train": + train(flags, wandb_run=wandb) + else: + test(flags) + + +if __name__ == "__main__": + flags = parse_args() + + if flags.profile: + import cProfile + from pstats import Stats + pr = cProfile.Profile() + pr.enable() + + wandb_run = None + if flags.exp_name: + wandb_run = wandb.init( + project="syllabus", + entity="", + config=flags, + save_code=True, + name=flags.exp_name, + resume="allow", + id=f"{flags.exp_name}-{flags.wandb_id}" + ) + main(flags, wandb=wandb_run) + + if flags.profile: + pr.disable() + stats = Stats(pr) + stats.sort_stats('cumtime').print_stats(200) diff --git a/examples/task_wrappers/__init__.py b/examples/task_wrappers/__init__.py new file mode 100644 index 00000000..7d04cff4 --- /dev/null +++ b/examples/task_wrappers/__init__.py @@ -0,0 +1,23 @@ +import warnings + +from .cartpole_task_wrapper import CartPoleTaskWrapper + +try: + from .minigrid_task_wrapper import MinigridTaskWrapper +except ImportError as e: + warnings.warn(f"Unable to import the following minigrid dependencies: {e.name}") + +try: + from .minihack_task_wrapper import MinihackTaskWrapper +except ImportError as e: + warnings.warn(f"Unable to import the following minihack dependencies: {e.name}") + +try: + from .nethack_wrappers import NethackTaskWrapper, RenderCharImagesWithNumpyWrapperV2 +except ImportError as e: + warnings.warn(f"Unable to import the following nle dependencies: {e.name}") + +try: + from .procgen_task_wrapper import ProcgenTaskWrapper +except ImportError as e: + warnings.warn(f"Unable to import the following procgen dependencies: {e.name}") diff --git a/examples/task_wrappers/cartpole_task_wrapper.py 
b/examples/task_wrappers/cartpole_task_wrapper.py new file mode 100644 index 00000000..4e590632 --- /dev/null +++ b/examples/task_wrappers/cartpole_task_wrapper.py @@ -0,0 +1,24 @@ +from gymnasium.spaces import Box + +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class CartPoleTaskWrapper(TaskWrapper): + def __init__(self, env): + super().__init__(env) + self.task_space = TaskSpace(Box(-0.3, 0.3, shape=(2,))) + self.task = (-0.02, 0.02) + self.total_reward = 0 + + def reset(self, *args, **kwargs): + self.total_reward = 0 + if "new_task" in kwargs: + new_task = kwargs.pop("new_task") + self.task = new_task + return self.env.reset(options={"low": self.task[0], "high": self.task[1]}) + + def _task_completion(self, obs, rew, term, trunc, info) -> float: + # Return percent of optimal reward + self.total_reward += rew + return self.total_reward / 500.0 diff --git a/examples/task_wrappers/minigrid_task_wrapper.py b/examples/task_wrappers/minigrid_task_wrapper.py new file mode 100644 index 00000000..de2e699b --- /dev/null +++ b/examples/task_wrappers/minigrid_task_wrapper.py @@ -0,0 +1,88 @@ +""" Task wrapper that can select a new MiniGrid task on reset. """ +import gymnasium as gym +import numpy as np +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class MinigridTaskWrapper(TaskWrapper): + """ + This wrapper allows you to change the task of an NLE environment. 
+ """ + def __init__(self, env: gym.Env): + super().__init__(env) + try: + from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX + except ImportError: + warnings.warn("Unable to import gym_minigrid.") + + self.observation_space = gym.spaces.Box( + low=0, + high=255, + shape=(self.env.width, self.env.height, 3), # number of cells + dtype='uint8' + ) + m, n, c = self.observation_space.shape + self.observation_space = gym.spaces.Box( + self.observation_space.low[0, 0, 0], + self.observation_space.high[0, 0, 0], + [c, m, n], + dtype=self.observation_space.dtype) + + # Set up task space + self.task_space = TaskSpace(gym.spaces.Discrete(4000), list(np.arange(4000))) + self.task = None + + def reset(self, new_task=None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + self.done = False + self.episode_return = 0 + + return self.observation(self.env.reset(**kwargs)["image"]) + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + Ignores requests for unknown tasks or task changes outside of a reset. + """ + seed = int(new_task) + self.task = seed + self.env.seed(seed) + + def step(self, action): + """ + Step through environment and update task completion. 
+ """ + # assert self._elapsed_steps is not None, "Cannot call env.step() before calling reset()" + obs, rew, term, trunc, info = self.env.step(action) + obs = self.observation(obs["image"]) + + self.episode_return += rew + self.done = term or trunc + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + + return obs, rew, term, trunc, info + + def observation(self, obs): + env = self.unwrapped + full_grid = env.grid.encode() + full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([ + OBJECT_TO_IDX['agent'], + COLOR_TO_IDX['red'], + env.agent_dir + ]) + + obs = full_grid + return obs.transpose(2, 0, 1) diff --git a/examples/task_wrappers/minihack_task_wrapper.py b/examples/task_wrappers/minihack_task_wrapper.py new file mode 100644 index 00000000..f7f61be2 --- /dev/null +++ b/examples/task_wrappers/minihack_task_wrapper.py @@ -0,0 +1,30 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import gymnasium as gym +from gymnasium import spaces +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class MinihackTaskWrapper(TaskWrapper): + """ + This wrapper simply changes the seed of a Minigrid environment. 
+ """ + def __init__(self, env: gym.Env): + super().__init__(env) + self.env = env + + self.task: str = 1 + + # Task completion metrics + self.episode_return = 0 + self.task_space = TaskSpace(spaces.Discrete(1000), list(range(1000))) + + def reset(self, new_task: int = None, **kwargs): + # Change task if new one is provided + # if new_task is not None: + # self.change_task(new_task) + + self.episode_return = 0 + self.current_task = new_task + self.env.seed(new_task) + return self.observation(self.env.reset(**kwargs)) diff --git a/examples/task_wrappers/nethack_wrappers.py b/examples/task_wrappers/nethack_wrappers.py new file mode 100644 index 00000000..ba0c2cae --- /dev/null +++ b/examples/task_wrappers/nethack_wrappers.py @@ -0,0 +1,493 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import os +import time +from typing import Any, Dict, List, Tuple + +import cv2 +import gymnasium as gym +import numpy as np +# import render_utils +from gymnasium.utils.step_api_compatibility import step_api_compatibility +from nle import nethack +from nle.env import base +from nle.env.tasks import (NetHackChallenge, NetHackEat, NetHackGold, + NetHackOracle, NetHackScore, NetHackScout, + NetHackStaircase, NetHackStaircasePet) +from numba import njit +from PIL import Image, ImageDraw, ImageFont +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 + +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +class NethackTaskWrapper(TaskWrapper): + """ + This wrapper allows you to change the task of an NLE environment. + + This wrapper was designed to meet two goals. + 1. Allow us to change the task of the NLE environment at the start of an episode + 2. Allow us to use the predefined NLE task definitions without copying/modifying their code. + This makes it easier to integrate with other work on nethack tasks or curricula. 
+ + Each task is defined as a subclass of the NLE, so you need to cast and reinitialize the + environment to change its task. This wrapper manipulates the __class__ property to achieve this, + but does so in a safe way. Specifically, we ensure that the instance variables needed for each + task are available and reset at the start of the episode regardless of which task is active. + """ + def __init__( + self, + env: gym.Env, + additional_tasks: List[base.NLE] = None, + use_default_tasks: bool = True, + env_kwargs: Dict[str, Any] = {}, + wrappers: List[Tuple[gym.Wrapper, List[Any], Dict[str, Any]]] = None + ): + super().__init__(env) + self.env = env + self.task = NetHackScore + self._init_kwargs = env_kwargs + if self.env.__class__ == NetHackChallenge: + self._no_progress_timeout = self._init_kwargs.pop("no_progress_timeout", 150) + + # This is set to False during reset + self.done = True + + # Add nethack tasks provided by the base NLE + task_list: List[base.NLE] = [] + if use_default_tasks: + task_list = [ + NetHackScore, + NetHackStaircase, + NetHackStaircasePet, + NetHackOracle, + NetHackGold, + NetHackEat, + NetHackScout, + ] + + # Add in custom nethack tasks + if additional_tasks: + for task in additional_tasks: + assert isinstance(task, base.NLE), "Env must subclass the base NLE" + task_list.append(task) + + self.task_list = task_list + gym_space = gym.spaces.Discrete(len(self.task_list)) + self.task_space = TaskSpace(gym_space, task_list) + + # Add goal space to observation + # self.observation_space = copy.deepcopy(self.env.observation_space) + # self.observation_space["goal"] = spaces.MultiBinary(len(self.task_list)) + + # Task completion metrics + self.episode_return = 0 + + # TODO: Deal with wrappers + self._nethack_env = self.env + while self._nethack_env.__class__ not in self.task_list and self._nethack_env.__class__ != NetHackChallenge: + if self._nethack_env.__class__ == GymV21CompatibilityV0: + self._nethack_env = self._nethack_env.gym_env + else: 
+ self._nethack_env = self._nethack_env.env + + # Initialize missing instance variables + self._nethack_env.oracle_glyph = None + + def _task_name(self, task): + return task.__name__ + + def reset(self, new_task=None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + # Change task if new one is provided + new_task = np.random.choice(self.task_list) + if new_task is not None: + self.change_task(new_task) + + self.done = False + self.episode_return = 0 + + return self.observation(self.env.reset(**kwargs)) + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + Ignores requests for unknown tasks or task changes outside of a reset. 
+ """ + # Ignore new task if mid episode + if self.task.__init__ != new_task.__init__ and not self.done: + print(f"Given task {self._task_name(new_task)} needs to be reinitialized.\ + Ignoring request to change task and keeping {self.task.__name__}") + return + + # Ignore if task is unknown + if new_task not in self.task_list: + print(f"Given task {new_task} not in task list.\ + Ignoring request to change task and keeping {self.env.__class__.__name__}") + return + + # Update current task + self.task = new_task + self._nethack_env.__class__ = new_task + + # If task requires reinitialization + # if type(self._nethack_env).__init__ != NetHackScore.__init__: + # self._nethack_env.__init__(actions=nethack.ACTIONS, **self._init_kwargs) + + def _encode_goal(self): + goal_encoding = np.zeros(len(self.task_list)) + index = self.task_list.index(self.task) + goal_encoding[index] = 1 + return goal_encoding + + def observation(self, observation): + """ + Parses current inventory and new items gained this timestep from the observation. + Returns a modified observation. + """ + # Add goal to observation + # observation['goal'] = self._encode_goal() + return observation + + def _task_completion(self, obs, rew, term, trunc, info): + # TODO: Add real task completion metrics + completion = 0.0 + if self.task == 0: + completion = self.episode_return / 1000 + elif self.task == 1: + completion = self.episode_return + elif self.task == 2: + completion = self.episode_return + elif self.task == 3: + completion = self.episode_return + elif self.task == 4: + completion = self.episode_return / 1000 + elif self.task == 5: + completion = self.episode_return / 10 + elif self.task == 6: + completion = self.episode_return / 100 + + return min(max(completion, 0.0), 1.0) + + def step(self, action): + """ + Step through environment and update task completion. 
+ """ + obs, rew, term, trunc, info = step_api_compatibility(self.env.step(action), output_truncation_bool=True) + # self.episode_return += rew + self.done = term or trunc + info["task_completion"] = self._task_completion(obs, rew, term, trunc, info) + return self.observation(obs), rew, term, trunc, info + + +SMALL_FONT_PATH = os.path.abspath("syllabus/examples/utils/Hack-Regular.ttf") + +# Mapping of 0-15 colors used. +# Taken from bottom image here. It seems about right +# https://i.stack.imgur.com/UQVe5.png +COLORS = [ + "#000000", + "#800000", + "#008000", + "#808000", + "#000080", + "#800080", + "#008080", + "#808080", # - flipped these ones around + "#C0C0C0", # | the gray-out dull stuff + "#FF0000", + "#00FF00", + "#FFFF00", + "#0000FF", + "#FF00FF", + "#00FFFF", + "#FFFFFF", +] + + +@njit +def _tile_characters_to_image( + out_image, + chars, + colors, + output_height_chars, + output_width_chars, + char_array, + offset_h, + offset_w, +): + """ + Build an image using cached images of characters in char_array to out_image + """ + char_height = char_array.shape[3] + char_width = char_array.shape[4] + for h in range(output_height_chars): + h_char = h + offset_h + # Stuff outside boundaries is not visible, so + # just leave it black + if h_char < 0 or h_char >= chars.shape[0]: + continue + for w in range(output_width_chars): + w_char = w + offset_w + if w_char < 0 or w_char >= chars.shape[1]: + continue + char = chars[h_char, w_char] + color = colors[h_char, w_char] + h_pixel = h * char_height + w_pixel = w * char_width + out_image[ + :, h_pixel : h_pixel + char_height, w_pixel : w_pixel + char_width + ] = char_array[char, color] + + +def _initialize_char_array(font_size, rescale_font_size): + """Draw all characters in PIL and cache them in numpy arrays + + if rescale_font_size is given, assume it is (width, height) + + Returns a np array of (num_chars, num_colors, char_height, char_width, 3) + """ + try: + font = ImageFont.truetype(SMALL_FONT_PATH, font_size) + 
except OSError as e: + raise ValueError("Change SMALL_FONT_PATH to point to syllabus/examples/utils/Hack-Regular.ttf") from e + + dummy_text = "".join( + [(chr(i) if chr(i).isprintable() else " ") for i in range(256)] + ) + _, _, image_width, image_height = font.getbbox(dummy_text) + # Above can not be trusted (or its siblings).... + image_width = int(np.ceil(image_width / 256) * 256) + + char_width = rescale_font_size[0] + char_height = rescale_font_size[1] + + char_array = np.zeros((256, 16, char_height, char_width, 3), dtype=np.uint8) + image = Image.new("RGB", (image_width, image_height)) + image_draw = ImageDraw.Draw(image) + for color_index in range(16): + image_draw.rectangle((0, 0, image_width, image_height), fill=(0, 0, 0)) + image_draw.text((0, 0), dummy_text, fill=COLORS[color_index], spacing=0) + + arr = np.array(image).copy() + arrs = np.array_split(arr, 256, axis=1) + for char_index in range(256): + char = arrs[char_index] + if rescale_font_size: + char = cv2.resize(char, rescale_font_size, interpolation=cv2.INTER_AREA) + char_array[char_index, color_index] = char + return char_array + + +class RenderCharImagesWithNumpyWrapper(gym.Wrapper): + """ + Render characters as images, using PIL to render characters like we humans see on screen + but then some caching and numpy stuff to speed up things. + + To speed things up, crop image around the player. 
+ """ + + def __init__( + self, + env, + font_size=9, + crop_size=12, + rescale_font_size=(6, 6), + blstats_cursor=False, + ): + super().__init__(env) + self.char_array = _initialize_char_array(font_size, rescale_font_size) + self.char_height = self.char_array.shape[2] + self.char_width = self.char_array.shape[3] + # Transpose for CHW + self.char_array = self.char_array.transpose(0, 1, 4, 2, 3) + + self.crop_size = crop_size + self.blstats_cursor = blstats_cursor + + self.half_crop_size = crop_size // 2 + self.output_height_chars = crop_size + self.output_width_chars = crop_size + self.chw_image_shape = ( + 3, + self.output_height_chars * self.char_height, + self.output_width_chars * self.char_width, + ) + + obs_spaces = { + "screen_image": gym.spaces.Box( + low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8 + ) + } + obs_spaces.update( + [ + (k, self.env.observation_space[k]) + for k in self.env.observation_space + if k not in ["tty_chars", "tty_colors"] + ] + ) + self.observation_space = gym.spaces.Dict(obs_spaces) + + def _render_text_to_image(self, obs): + chars = obs["tty_chars"] + colors = obs["tty_colors"] + offset_w = 0 + offset_h = 0 + if self.crop_size: + # Center around player + if self.blstats_cursor: + center_x, center_y = obs["blstats"][:2] + else: + center_y, center_x = obs["tty_cursor"] + offset_h = center_y - self.half_crop_size + offset_w = center_x - self.half_crop_size + + out_image = np.zeros(self.chw_image_shape, dtype=np.uint8) + + _tile_characters_to_image( + out_image=out_image, + chars=chars, + colors=colors, + output_height_chars=self.output_height_chars, + output_width_chars=self.output_width_chars, + char_array=self.char_array, + offset_h=offset_h, + offset_w=offset_w, + ) + + obs["screen_image"] = out_image + return obs + + def step(self, action): + obs, reward, done, info = self.env.step(action) + obs = self._render_text_to_image(obs) + return obs, reward, done, info + + def reset(self): + obs = self.env.reset() + obs = 
self._render_text_to_image(obs) + return obs + + +class RenderCharImagesWithNumpyWrapperV2(gym.Wrapper): + """ + Same as V1, but simpler and faster. + """ + + def __init__( + self, + env, + font_size=9, + crop_size=12, + rescale_font_size=(6, 6), + ): + super().__init__(env) + self.char_array = _initialize_char_array(font_size, rescale_font_size) + self.char_height = self.char_array.shape[2] + self.char_width = self.char_array.shape[3] + # Transpose for CHW + self.char_array = self.char_array.transpose(0, 1, 4, 2, 3) + self.char_array = np.ascontiguousarray(self.char_array) + self.crop_size = crop_size + + crop_rows = crop_size or nethack.nethack.TERMINAL_SHAPE[0] + crop_cols = crop_size or nethack.nethack.TERMINAL_SHAPE[1] + + self.chw_image_shape = ( + 3, + crop_rows * self.char_height, + crop_cols * self.char_width, + ) + + obs_spaces = { + "screen_image": gym.spaces.Box( + low=0, high=255, shape=self.chw_image_shape, dtype=np.uint8 + ) + } + obs_spaces.update( + [ + (k, self.env.observation_space[k]) + for k in self.env.observation_space + # if k not in ["tty_chars", "tty_colors"] + ] + ) + self.observation_space = gym.spaces.Dict(obs_spaces) + + def _populate_obs(self, obs): + screen = np.zeros(self.chw_image_shape, order="C", dtype=np.uint8) + render_utils.render_crop( + obs["tty_chars"], + obs["tty_colors"], + obs["tty_cursor"], + self.char_array, + screen, + crop_size=self.crop_size, + ) + obs["screen_image"] = screen + + def step(self, action): + obs, reward, term, trunc, info = self.env.step(action) + self._populate_obs(obs) + return obs, reward, term, trunc, info + + def reset(self): + obs, info = self.env.reset() + self._populate_obs(obs) + return obs, info + + +if __name__ == "__main__": + def run_episode(env, task: str = None, verbose=1): + env.reset(new_task=task) + task_name = type(env.unwrapped).__name__ + term = trunc = False + ep_rew = 0 + while not (term or trunc): + action = env.action_space.sample() + _, rew, term, trunc, _ = env.step(action) 
+ ep_rew += rew + if verbose: + print(f"Episodic reward for {task_name}: {ep_rew}") + + print("Testing NethackTaskWrapper") + N_EPISODES = 100 + + # Initialize NLE + nethack_env = NetHackScore() + nethack_env = GymV21CompatibilityV0(env=nethack_env) + + nethack_task_env = NethackTaskWrapper(nethack_env) + + task_list = [ + NetHackScore, + NetHackStaircase, + NetHackStaircasePet, + NetHackOracle, + NetHackGold, + NetHackEat, + NetHackScout, + ] + + start_time = time.time() + + for _ in range(N_EPISODES): + run_episode(nethack_task_env, verbose=0) + + end_time = time.time() + print(f"Run time same task: {end_time - start_time}") + start_time = time.time() + + for i in range(N_EPISODES): + nethack_task = task_list[i % 7] + run_episode(nethack_task_env, task=nethack_task, verbose=0) + + end_time = time.time() + print(f"Run time swapping tasks: {end_time - start_time}") diff --git a/examples/task_wrappers/pistonball_task_wrapper.py b/examples/task_wrappers/pistonball_task_wrapper.py new file mode 100644 index 00000000..d60abb7a --- /dev/null +++ b/examples/task_wrappers/pistonball_task_wrapper.py @@ -0,0 +1,35 @@ +""" Task wrapper for NLE that can change tasks at reset using the NLE's task definition format. """ +import gymnasium as gym +from gymnasium import spaces +from pettingzoo.butterfly import pistonball_v6 +from syllabus.core import PettingZooTaskWrapper +from syllabus.task_space import TaskSpace + + +class PistonballTaskWrapper(PettingZooTaskWrapper): + """ + This wrapper simply changes the seed of a Minigrid environment. 
+ """ + def __init__(self, env: gym.Env): + super().__init__(env) + self.env = env + self.env.unwrapped.task: str = 1 + + # Task completion metrics + self.episode_return = 0 + self.task_space = TaskSpace(spaces.Discrete(11), list(range(11))) # 0.1 - 1.0 friction + + def reset(self, new_task: int = None, **kwargs): + # Change task if new one is provided + # if new_task is not None: + # self.change_task(new_task) + + self.episode_return = 0 + if new_task is not None: + task = new_task / 10 + # Inject current_task into the environment + self.env = pistonball_v6.parallel_env( + ball_friction=task, continuous=False, max_cycles=125 + ) + self.env.unwrapped.task = new_task + return self.observation(self.env.reset(**kwargs)) diff --git a/examples/task_wrappers/procgen_task_wrapper.py b/examples/task_wrappers/procgen_task_wrapper.py new file mode 100644 index 00000000..2296fd58 --- /dev/null +++ b/examples/task_wrappers/procgen_task_wrapper.py @@ -0,0 +1,86 @@ +import gymnasium as gym +import numpy as np +from syllabus.core import TaskWrapper +from syllabus.task_space import TaskSpace + + +PROCGEN_RETURN_BOUNDS = { + "coinrun": (5, 10), + "starpilot": (2.5, 64), + "caveflyer": (3.5, 12), + "dodgeball": (1.5, 19), + "fruitbot": (-1.5, 32.4), + "chaser": (0.5, 13), + "miner": (1.5, 13), + "jumper": (3, 10), + "leaper": (3, 10), + "maze": (5, 10), + "bigfish": (1, 40), + "heist": (3.5, 10), + "climber": (2, 12.6), + "plunder": (4.5, 30), + "ninja": (3.5, 10), + "bossfight": (0.5, 13), +} + + +class ProcgenTaskWrapper(TaskWrapper): + """ + This wrapper allows you to change the task of an NLE environment. 
+ """ + def __init__(self, env: gym.Env, env_id, seed=0): + super().__init__(env) + self.task_space = TaskSpace(gym.spaces.Discrete(200), list(np.arange(0, 200))) + self.env_id = env_id + self.task = seed + self.seed(seed) + self.episode_return = 0 + + self.observation_space = self.env.observation_space + + def seed(self, seed): + self.env.gym_env.unwrapped._venv.seed(int(seed), 0) + + def reset(self, new_task=None, **kwargs): + """ + Resets the environment along with all available tasks, and change the current task. + + This ensures that all instance variables are reset, not just the ones for the current task. + We do this efficiently by keeping track of which reset functions have already been called, + since very few tasks override reset. If new_task is provided, we change the task before + calling the final reset. + """ + self.episode_return = 0.0 + + # Change task if new one is provided + if new_task is not None: + self.change_task(new_task) + + obs, info = self.env.reset(**kwargs) + return self.observation(obs), info + + def change_task(self, new_task: int): + """ + Change task by directly editing environment class. + + Ignores requests for unknown tasks or task changes outside of a reset. + """ + seed = int(new_task) + self.task = seed + self.seed(seed) + + def step(self, action): + """ + Step through environment and update task completion. 
+ """ + obs, rew, term, trunc, info = self.env.step(action) + self.episode_return += rew + + env_min, env_max = PROCGEN_RETURN_BOUNDS[self.env_id] + normalized_return = (self.episode_return - env_min) / (env_max - env_min) + info["task_completion"] = normalized_return + + return self.observation(obs), rew, term, trunc, info + + def observation(self, obs): + return obs diff --git a/examples/training_scripts/cleanrl_procgen_centralplr.py b/examples/training_scripts/cleanrl_procgen_centralplr.py new file mode 100644 index 00000000..b848d693 --- /dev/null +++ b/examples/training_scripts/cleanrl_procgen_centralplr.py @@ -0,0 +1,536 @@ +""" An example applying Syllabus Prioritized Level Replay to Procgen. This code is based on https://github.com/facebookresearch/level-replay/blob/main/train.py + +NOTE: In order to efficiently change the seed of a procgen environment directly without reinitializing it, +we rely on Minqi Jiang's custom branch of procgen found here: https://github.com/minqi/procgen +""" +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool + +import gym as openai_gym +import gymnasium as gym +import numpy as np +import procgen # noqa: F401 +from procgen import ProcgenEnv +import torch +import torch.nn as nn +import torch.optim as optim +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 +from torch.utils.tensorboard import SummaryWriter + +from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum +from syllabus.curricula import CentralizedPrioritizedLevelReplay, DomainRandomization, LearningProgressCurriculum, SequentialCurriculum +from syllabus.examples.models import ProcgenAgent +from syllabus.examples.task_wrappers import ProcgenTaskWrapper +from syllabus.examples.utils.vecenv import VecMonitor, VecNormalize, VecExtractDictObs + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", 
type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--logging-dir", type=str, default=".", + help="the base directory for logging and wandb storage.") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="starpilot", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=int(25e6), + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=256, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggle 
learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.999, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, default=8, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=3, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + + # Procgen arguments + parser.add_argument("--full-dist", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Train on full distribution of levels.") + + # Curriculum arguments + parser.add_argument("--curriculum", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will use curriculum learning") + parser.add_argument("--curriculum-method", type=str, default="plr", + help="curriculum 
method to use") + parser.add_argument("--num-eval-episodes", type=int, default=10, + help="the number of episodes to evaluate the agent on after each policy update.") + + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +PROCGEN_RETURN_BOUNDS = { + "coinrun": (5, 10), + "starpilot": (2.5, 64), + "caveflyer": (3.5, 12), + "dodgeball": (1.5, 19), + "fruitbot": (-1.5, 32.4), + "chaser": (0.5, 13), + "miner": (1.5, 13), + "jumper": (3, 10), + "leaper": (3, 10), + "maze": (5, 10), + "bigfish": (1, 40), + "heist": (3.5, 10), + "climber": (2, 12.6), + "plunder": (4.5, 30), + "ninja": (3.5, 10), + "bossfight": (0.5, 13), +} + + +def make_env(env_id, seed, curriculum=None, start_level=0, num_levels=1): + def thunk(): + env = openai_gym.make(f"procgen-{env_id}-v0", distribution_mode="easy", start_level=start_level, num_levels=num_levels) + env = GymV21CompatibilityV0(env=env) + if curriculum is not None: + env = ProcgenTaskWrapper(env, env_id, seed=seed) + env = MultiProcessingSyncWrapper( + env, + curriculum.get_components(), + update_on_step=False, + task_space=env.task_space, + ) + return env + return thunk + + +def wrap_vecenv(vecenv): + vecenv.is_vector_env = True + vecenv = VecMonitor(venv=vecenv, filename=None, keep_buf=100) + vecenv = VecNormalize(venv=vecenv, ob=False, ret=True) + return vecenv + + +def full_level_replay_evaluate( + env_name, + policy, + num_episodes, + device, + num_levels=1 # Not used +): + policy.eval() + + eval_envs = ProcgenEnv( + num_envs=args.num_eval_episodes, env_name=env_name, num_levels=1, start_level=0, distribution_mode="easy", paint_vel_info=False + ) + eval_envs = VecExtractDictObs(eval_envs, "rgb") + eval_envs = wrap_vecenv(eval_envs) + + # Seed environments + seeds = [int.from_bytes(os.urandom(3), byteorder="little") for _ in range(num_episodes)] + for i, seed in enumerate(seeds): + 
eval_envs.seed(seed, i) + + eval_obs, _ = eval_envs.reset() + eval_episode_rewards = [-1] * num_episodes + + while -1 in eval_episode_rewards: + with torch.no_grad(): + eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) + + eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) + for i, info in enumerate(infos): + if 'episode' in info.keys() and eval_episode_rewards[i] == -1: + eval_episode_rewards[i] = info['episode']['r'] + + mean_returns = np.mean(eval_episode_rewards) + stddev_returns = np.std(eval_episode_rewards) + env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] + normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) + policy.train() + return mean_returns, stddev_returns, normalized_mean_returns + + +def level_replay_evaluate( + env_name, + policy, + num_episodes, + device, + num_levels=0 +): + policy.eval() + + eval_envs = ProcgenEnv( + num_envs=args.num_eval_episodes, env_name=env_name, num_levels=num_levels, start_level=0, distribution_mode="easy", paint_vel_info=False + ) + eval_envs = VecExtractDictObs(eval_envs, "rgb") + eval_envs = wrap_vecenv(eval_envs) + eval_obs, _ = eval_envs.reset() + eval_episode_rewards = [-1] * num_episodes + + while -1 in eval_episode_rewards: + with torch.no_grad(): + eval_action, _, _, _ = policy.get_action_and_value(torch.Tensor(eval_obs).to(device), deterministic=False) + + eval_obs, _, truncs, terms, infos = eval_envs.step(eval_action.cpu().numpy()) + for i, info in enumerate(infos): + if 'episode' in info.keys() and eval_episode_rewards[i] == -1: + eval_episode_rewards[i] = info['episode']['r'] + + # print(eval_episode_rewards) + mean_returns = np.mean(eval_episode_rewards) + stddev_returns = np.std(eval_episode_rewards) + env_min, env_max = PROCGEN_RETURN_BOUNDS[args.env_id] + normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min) + policy.train() + return mean_returns, stddev_returns, 
# Script entry point: PPO on Procgen with an optional centrally-updated
# Syllabus curriculum. Flow: parse args -> logging -> seeding -> curriculum ->
# vectorised envs -> agent -> rollout / update / evaluate loop.
if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
            dir=args.logging_dir
        )
        # wandb.run.log_code("./syllabus/examples")

    # The run directory must be an f-string; without the prefix every run was
    # written to the literal directory "runs/{run_name}" and runs overwrote
    # each other's event files.
    writer = SummaryWriter(os.path.join(args.logging_dir, f"./runs/{run_name}"))
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    print("Device:", device)

    # Curriculum setup
    curriculum = None
    if args.curriculum:
        # A throwaway env is built only to read its task space.
        sample_env = openai_gym.make(f"procgen-{args.env_id}-v0")
        sample_env = GymV21CompatibilityV0(env=sample_env)
        sample_env = ProcgenTaskWrapper(sample_env, args.env_id, seed=args.seed)

        # Initialize Curriculum Method
        if args.curriculum_method == "plr":
            print("Using prioritized level replay.")
            curriculum = CentralizedPrioritizedLevelReplay(
                sample_env.task_space,
                num_steps=args.num_steps,
                num_processes=args.num_envs,
                gamma=args.gamma,
                gae_lambda=args.gae_lambda,
                task_sampler_kwargs_dict={"strategy": "value_l1"}
            )
        elif args.curriculum_method == "dr":
            print("Using domain randomization.")
            curriculum = DomainRandomization(sample_env.task_space)
        elif args.curriculum_method == "lp":
            print("Using learning progress.")
            curriculum = LearningProgressCurriculum(sample_env.task_space)
        elif args.curriculum_method == "sq":
            print("Using sequential curriculum.")
            curricula = []
            stopping = []
            for i in range(199):
                curricula.append(i + 1)
                stopping.append("steps>=50000")
                curricula.append(list(range(i + 1)))
                stopping.append("steps>=50000")
            # One fewer stopping condition than stages: the last stage never ends.
            curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space)
        else:
            raise ValueError(f"Unknown curriculum method {args.curriculum_method}")
        curriculum = make_multiprocessing_curriculum(curriculum)
        del sample_env

    # env setup
    print("Creating env")
    envs = gym.vector.AsyncVectorEnv(
        [
            make_env(
                args.env_id,
                args.seed + i,
                curriculum=curriculum if args.curriculum else None,
                num_levels=1 if args.curriculum else 0
            )
            for i in range(args.num_envs)
        ]
    )
    envs = wrap_vecenv(envs)

    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
    print("Creating agent")
    agent = ProcgenAgent(
        envs.single_observation_space.shape,
        envs.single_action_space.n,
        arch="large",
        base_kwargs={'recurrent': False, 'hidden_size': 256}
    ).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size
    episode_rewards = deque(maxlen=10)
    completed_episodes = 0

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, term, trunc, info = envs.step(action.cpu().numpy())
            done = np.logical_or(term, trunc)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
            completed_episodes += sum(done)

            # Only the first finished episode of this step is logged.
            for item in info:
                if "episode" in item.keys():
                    episode_rewards.append(item['episode']['r'])
                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step)
                    if curriculum is not None:
                        curriculum.log_metrics(writer, global_step)
                    break

            # Syllabus curriculum update
            if args.curriculum and args.curriculum_method == "plr":
                with torch.no_grad():
                    next_value = agent.get_value(next_obs)
                tasks = envs.get_attr("task")

                # Named plr_update so it does not rebind the outer loop
                # variable `update` (the PPO iteration counter).
                plr_update = {
                    "update_type": "on_demand",
                    "metrics": {
                        "value": value,
                        "next_value": next_value,
                        "rew": reward,
                        "dones": done,
                        "tasks": tasks,
                    },
                }
                curriculum.update(plr_update)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        nextvalues = values[t + 1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                advantages = returns - values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Early-stop the remaining epochs once the policy has moved too far.
            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # Evaluate agent
        mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=0
        )
        full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=0
        )
        mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=200
        )
        full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=200
        )

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("charts/episode_returns", np.mean(episode_rewards), global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step)
        writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step)
        writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step)
        writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step)
        writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step)
        writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step)

        writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step)
        writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step)
        writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step)
        writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step)
        writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step)
        writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step)

        # Logged against global_step: the inner-loop `step` is always
        # num_steps - 1 here, which stacked every point on one x value.
        writer.add_scalar("curriculum/completed_episodes", completed_episodes, global_step)

    envs.close()
    writer.close()
b/examples/training_scripts/cleanrl_procgen_plr.py new file mode 100644 index 00000000..dabcd500 --- /dev/null +++ b/examples/training_scripts/cleanrl_procgen_plr.py @@ -0,0 +1,528 @@ +""" An example applying Syllabus Prioritized Level Replay to Procgen. This code is based on https://github.com/facebookresearch/level-replay/blob/main/train.py + +NOTE: In order to efficiently change the seed of a procgen environment directly without reinitializing it, +we rely on Minqi Jiang's custom branch of procgen found here: https://github.com/minqi/procgen +""" +import argparse +import os +import random +import time +from collections import deque +from distutils.util import strtobool + +import gym as openai_gym +import gymnasium as gym +import numpy as np +import procgen # noqa: F401 +from procgen import ProcgenEnv +import torch +import torch.nn as nn +import torch.optim as optim +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 +from torch.utils.tensorboard import SummaryWriter + +from syllabus.core import MultiProcessingSyncWrapper, make_multiprocessing_curriculum +from syllabus.curricula import PrioritizedLevelReplay, DomainRandomization, LearningProgressCurriculum, SequentialCurriculum +from syllabus.examples.models import ProcgenAgent +from syllabus.examples.task_wrappers import ProcgenTaskWrapper +from syllabus.examples.utils.vecenv import VecMonitor, VecNormalize, VecExtractDictObs + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="if toggled, `torch.backends.cudnn.deterministic=False`") + parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + 
help="if toggled, cuda will be enabled by default") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="syllabus", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + parser.add_argument("--logging-dir", type=str, default=".", + help="the base directory for logging and wandb storage.") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="starpilot", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=int(25e6), + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=5e-4, + help="the learning rate of the optimizer") + parser.add_argument("--num-envs", type=int, default=64, + help="the number of parallel game environments") + parser.add_argument("--num-steps", type=int, default=256, + help="the number of steps to run in each environment per policy rollout") + parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="Toggle learning rate annealing for policy and value networks") + parser.add_argument("--gae", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Use GAE for advantage computation") + parser.add_argument("--gamma", type=float, default=0.999, + help="the discount factor gamma") + parser.add_argument("--gae-lambda", type=float, default=0.95, + help="the lambda for the general advantage estimation") + parser.add_argument("--num-minibatches", type=int, 
default=8, + help="the number of mini-batches") + parser.add_argument("--update-epochs", type=int, default=3, + help="the K epochs to update the policy") + parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles advantages normalization") + parser.add_argument("--clip-coef", type=float, default=0.2, + help="the surrogate clipping coefficient") + parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Toggles whether or not to use a clipped loss for the value function, as per the paper.") + parser.add_argument("--ent-coef", type=float, default=0.01, + help="coefficient of the entropy") + parser.add_argument("--vf-coef", type=float, default=0.5, + help="coefficient of the value function") + parser.add_argument("--max-grad-norm", type=float, default=0.5, + help="the maximum norm for the gradient clipping") + parser.add_argument("--target-kl", type=float, default=None, + help="the target KL divergence threshold") + + # Procgen arguments + parser.add_argument("--full-dist", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, + help="Train on full distribution of levels.") + + # Curriculum arguments + parser.add_argument("--curriculum", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will use curriculum learning") + parser.add_argument("--curriculum-method", type=str, default="plr", + help="curriculum method to use") + parser.add_argument("--num-eval-episodes", type=int, default=10, + help="the number of episodes to evaluate the agent on after each policy update.") + + args = parser.parse_args() + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + # fmt: on + return args + + +PROCGEN_RETURN_BOUNDS = { + "coinrun": (5, 10), + "starpilot": (2.5, 64), + "caveflyer": (3.5, 12), + "dodgeball": (1.5, 
def make_env(env_id, seed, curriculum=None, start_level=0, num_levels=1):
    """Return a zero-argument factory that builds one ``procgen-{env_id}-v0`` env.

    The factory creates the env in "easy" distribution mode with the given
    ``start_level`` / ``num_levels`` and wraps it in ``GymV21CompatibilityV0``.
    When ``curriculum`` is provided, the env is additionally wrapped in
    ``ProcgenTaskWrapper`` (seeded with ``seed``) and ``MultiProcessingSyncWrapper``
    built from the curriculum's components. ``seed`` is unused when
    ``curriculum`` is ``None``.
    """
    def thunk():
        raw_env = openai_gym.make(
            f"procgen-{env_id}-v0",
            distribution_mode="easy",
            start_level=start_level,
            num_levels=num_levels,
        )
        compat_env = GymV21CompatibilityV0(env=raw_env)

        # No curriculum: the compatibility-wrapped env is the final product.
        if curriculum is None:
            return compat_env

        task_env = ProcgenTaskWrapper(compat_env, env_id, seed=seed)
        return MultiProcessingSyncWrapper(
            task_env,
            curriculum.get_components(),
            update_on_step=curriculum.requires_step_updates,
            task_space=task_env.task_space,
        )

    return thunk
def level_replay_evaluate(
    env_name,
    policy,
    num_episodes,
    device,
    num_levels=0
):
    """Run ``policy`` for one episode in each of ``num_episodes`` parallel Procgen envs.

    Args:
        env_name: Procgen game name; also the key into ``PROCGEN_RETURN_BOUNDS``.
        policy: agent exposing ``eval()``, ``train()`` and
            ``get_action_and_value(obs, deterministic=...)``.
        num_episodes: number of evaluation envs; one return is recorded per env.
        device: torch device the observations are moved to.
        num_levels: forwarded to ``ProcgenEnv``.

    Returns:
        ``(mean_return, stddev_return, normalized_mean_return)``, where the
        normalized value rescales the mean with the game's
        ``PROCGEN_RETURN_BOUNDS`` entry.
    """
    policy.eval()
    try:
        # Size the vec-env from the num_episodes argument, not the global
        # args.num_eval_episodes. If the two ever differed, the loop below
        # would either index out of range or never terminate.
        eval_envs = ProcgenEnv(
            num_envs=num_episodes,
            env_name=env_name,
            num_levels=num_levels,
            start_level=0,
            distribution_mode="easy",
            paint_vel_info=False,
        )
        eval_envs = VecExtractDictObs(eval_envs, "rgb")
        eval_envs = wrap_vecenv(eval_envs)
        try:
            eval_obs, _ = eval_envs.reset()
            # None marks "episode not finished yet". The old -1 sentinel
            # collided with a genuine episode return of -1.
            eval_episode_rewards = [None] * num_episodes

            while any(ret is None for ret in eval_episode_rewards):
                with torch.no_grad():
                    eval_action, _, _, _ = policy.get_action_and_value(
                        torch.Tensor(eval_obs).to(device), deterministic=False
                    )

                eval_obs, _, _, _, infos = eval_envs.step(eval_action.cpu().numpy())
                for i, info in enumerate(infos):
                    # Keep only the first completed episode of each env.
                    if 'episode' in info and eval_episode_rewards[i] is None:
                        eval_episode_rewards[i] = info['episode']['r']
        finally:
            # This function is called on every policy update; release the envs
            # instead of leaking one set per call.
            eval_envs.close()
    finally:
        # Restore training mode even if evaluation raised.
        policy.train()

    mean_returns = np.mean(eval_episode_rewards)
    stddev_returns = np.std(eval_episode_rewards)
    env_min, env_max = PROCGEN_RETURN_BOUNDS[env_name]
    normalized_mean_returns = (mean_returns - env_min) / (env_max - env_min)
    return mean_returns, stddev_returns, normalized_mean_returns
# Script entry point: PPO on Procgen with an optional Syllabus curriculum
# (PLR, domain randomization, learning progress or sequential).
# Flow: parse args -> logging -> seeding -> curriculum -> vectorised envs ->
# agent -> rollout / update / evaluate loop.
if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
            dir=args.logging_dir
        )
        # wandb.run.log_code("./syllabus/examples")

    # The run directory must be an f-string; without the prefix every run was
    # written to the literal directory "runs/{run_name}" and runs overwrote
    # each other's event files.
    writer = SummaryWriter(os.path.join(args.logging_dir, f"./runs/{run_name}"))
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    print("Device:", device)

    # Curriculum setup
    curriculum = None
    if args.curriculum:
        # A throwaway env is built only to read its task/observation spaces.
        sample_env = openai_gym.make(f"procgen-{args.env_id}-v0")
        sample_env = GymV21CompatibilityV0(env=sample_env)
        sample_env = ProcgenTaskWrapper(sample_env, args.env_id, seed=args.seed)

        # Initialize Curriculum Method
        if args.curriculum_method == "plr":
            print("Using prioritized level replay.")
            curriculum = PrioritizedLevelReplay(
                sample_env.task_space,
                sample_env.observation_space,
                num_steps=args.num_steps,
                num_processes=args.num_envs,
                gamma=args.gamma,
                gae_lambda=args.gae_lambda,
                task_sampler_kwargs_dict={"strategy": "value_l1"},
                get_value=make_value_fn(),
            )
        elif args.curriculum_method == "dr":
            print("Using domain randomization.")
            curriculum = DomainRandomization(sample_env.task_space)
        elif args.curriculum_method == "lp":
            print("Using learning progress.")
            curriculum = LearningProgressCurriculum(sample_env.task_space)
        elif args.curriculum_method == "sq":
            print("Using sequential curriculum.")
            curricula = []
            stopping = []
            for i in range(199):
                curricula.append(i + 1)
                stopping.append("steps>=50000")
                curricula.append(list(range(i + 1)))
                stopping.append("steps>=50000")
            # One fewer stopping condition than stages: the last stage never ends.
            curriculum = SequentialCurriculum(curricula, stopping[:-1], sample_env.task_space)
        else:
            raise ValueError(f"Unknown curriculum method {args.curriculum_method}")
        curriculum = make_multiprocessing_curriculum(curriculum)
        del sample_env

    # env setup
    print("Creating env")
    envs = gym.vector.AsyncVectorEnv(
        [
            make_env(
                args.env_id,
                args.seed + i,
                curriculum=curriculum if args.curriculum else None,
                num_levels=1 if args.curriculum else 0
            )
            for i in range(args.num_envs)
        ]
    )
    envs = wrap_vecenv(envs)

    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
    print("Creating agent")
    agent = ProcgenAgent(
        envs.single_observation_space.shape,
        envs.single_action_space.n,
        arch="large",
        base_kwargs={'recurrent': False, 'hidden_size': 256}
    ).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size
    episode_rewards = deque(maxlen=10)
    completed_episodes = 0

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, term, trunc, info = envs.step(action.cpu().numpy())
            done = np.logical_or(term, trunc)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
            completed_episodes += sum(done)

            # Only the first finished episode of this step is logged.
            for item in info:
                if "episode" in item.keys():
                    episode_rewards.append(item['episode']['r'])
                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step)
                    if curriculum is not None:
                        curriculum.log_metrics(writer, global_step)
                    break

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        nextvalues = values[t + 1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                advantages = returns - values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            # Early-stop the remaining epochs once the policy has moved too far.
            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # Evaluate agent
        mean_eval_returns, stddev_eval_returns, normalized_mean_eval_returns = level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=0
        )
        full_mean_eval_returns, full_stddev_eval_returns, full_normalized_mean_eval_returns = full_level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=0
        )
        mean_train_returns, stddev_train_returns, normalized_mean_train_returns = level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=200
        )
        full_mean_train_returns, full_stddev_train_returns, full_normalized_mean_train_returns = full_level_replay_evaluate(
            args.env_id, agent, args.num_eval_episodes, device, num_levels=200
        )

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("charts/episode_returns", np.mean(episode_rewards), global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        writer.add_scalar("test_eval/mean_episode_return", mean_eval_returns, global_step)
        writer.add_scalar("test_eval/normalized_mean_eval_return", normalized_mean_eval_returns, global_step)
        writer.add_scalar("test_eval/stddev_eval_return", stddev_eval_returns, global_step)
        writer.add_scalar("test_eval/full_mean_episode_return", full_mean_eval_returns, global_step)
        writer.add_scalar("test_eval/full_normalized_mean_eval_return", full_normalized_mean_eval_returns, global_step)
        writer.add_scalar("test_eval/full_stddev_eval_return", full_stddev_eval_returns, global_step)

        writer.add_scalar("train_eval/mean_episode_return", mean_train_returns, global_step)
        writer.add_scalar("train_eval/normalized_mean_train_return", normalized_mean_train_returns, global_step)
        writer.add_scalar("train_eval/stddev_train_return", stddev_train_returns, global_step)
        writer.add_scalar("train_eval/full_mean_episode_return", full_mean_train_returns, global_step)
        writer.add_scalar("train_eval/full_normalized_mean_train_return", full_normalized_mean_train_returns, global_step)
        writer.add_scalar("train_eval/full_stddev_train_return", full_stddev_train_returns, global_step)

        # Logged against global_step: the inner-loop `step` is always
        # num_steps - 1 here, which stacked every point on one x value.
        writer.add_scalar("curriculum/completed_episodes", completed_episodes, global_step)

    envs.close()
    writer.close()
class VecEnv:
    """
    Abstract base class for an asynchronous, vectorized environment.

    Batches several copies of an environment so that each observation is a
    batch of observations and the expected action is a batch of actions,
    one per environment.
    """
    closed = False
    viewer = None

    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.action_space = action_space

    def reset(self):
        """
        Reset every environment and return an array of observations
        (or a dict of observation arrays).

        Any work still pending from step_async is cancelled; step_wait()
        must not be called again until step_async() has been re-invoked.

        The base implementation does nothing and returns None.
        """

    def step_async(self, actions):
        """
        Ask all environments to begin a step with the given actions.
        Results are collected with step_wait().

        Must not be called while a previous step_async is still pending.

        The base implementation does nothing and returns None.
        """

    def step_wait(self):
        """
        Block until the step started by step_async() has finished.

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects

        The base implementation does nothing and returns None.
        """

    def close_extras(self):
        """
        Hook for releasing resources beyond those handled by this base
        class. Invoked by close() only while the env is not yet closed.
        """

    def close(self):
        # Idempotent: everything below runs at most once.
        if not self.closed:
            if self.viewer is not None:
                self.viewer.close()
            self.close_extras()
            self.closed = True

    def step(self, actions):
        """
        Synchronous step: step_async followed by step_wait.

        Kept for backwards compatibility.
        """
        self.step_async(actions)
        outcome = self.step_wait()
        return outcome

    def step_env(self, actions, reset_random=False):
        # Neither starter method is defined on this base class; subclasses
        # that use step_env presumably provide them (verify against callers).
        if reset_random:
            begin_step = self.step_env_reset_random_async
        else:
            begin_step = self.step_env_async
        begin_step(actions)
        return self.step_wait()

    def render(self, mode='human'):
        raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        # Peel wrappers recursively until a non-wrapper env is reached.
        return self.venv.unwrapped if isinstance(self, VecEnvWrapper) else self

    def get_viewer(self):
        # Lazily create the viewer on first use, then reuse it.
        if self.viewer is not None:
            return self.viewer
        from gym.envs.classic_control import rendering
        self.viewer = rendering.SimpleImageViewer()
        return self.viewer
else: + obs, rews, terms, truncs, infos = env_outputs + return self.process(obs), rews, terms, truncs, infos + + +class VecExtractDictObs(VecEnvObservationWrapper): + def __init__(self, venv, key): + self.key = key + super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) + + def process(self, obs): + return obs[self.key] + + +class VecNormalize(VecEnvWrapper): + """ + A vectorized wrapper that normalizes the observations + and returns from an environment. + """ + + def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, use_tf=False): + VecEnvWrapper.__init__(self, venv) + self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None + self.ret_rms = RunningMeanStd(shape=()) if ret else None + self.clipob = clipob + self.cliprew = cliprew + self.ret = np.zeros(self.num_envs) + self.gamma = gamma + self.epsilon = epsilon + + def step_wait(self): + obs, rews, terms, truncs, infos = self.venv.step_wait() + news = np.logical_or(terms, truncs) + self.ret = self.ret * self.gamma + rews + obs = self._obfilt(obs) + if self.ret_rms: + self.ret_rms.update(self.ret) + rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) + self.ret[news] = 0. 
+ return obs, rews, terms, truncs, infos + + def _obfilt(self, obs): + if self.ob_rms: + self.ob_rms.update(obs) + obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) + return obs + else: + return obs + + def reset(self, seed=None): + self.ret = np.zeros(self.num_envs) + if seed is not None: + obs, infos = self.venv.reset(seed=seed) + else: + obs, infos = self.venv.reset() + return self._obfilt(obs), infos + + +class VecMonitor(VecEnvWrapper): + def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()): + VecEnvWrapper.__init__(self, venv) + self.eprets = None + self.eplens = None + self.epcount = 0 + self.tstart = time.time() + self.results_writer = None + self.info_keywords = info_keywords + self.keep_buf = keep_buf + if self.keep_buf: + self.epret_buf = deque([], maxlen=keep_buf) + self.eplen_buf = deque([], maxlen=keep_buf) + + def reset(self, seed=None): + if seed is not None: + obs, infos = self.venv.reset(seed=seed) + else: + obs, infos = self.venv.reset() + self.eprets = np.zeros(self.num_envs, 'f') + self.eplens = np.zeros(self.num_envs, 'i') + return obs, infos + + def step_wait(self): + obs, rews, terms, truncs, infos = self.venv.step_wait() + dones = np.logical_or(terms, truncs) + self.eprets += rews + self.eplens += 1 + # Convert dict of lists to list of dicts + if isinstance(infos, dict): + infos = [dict(zip(infos, t)) for t in zip(*infos.values())] + newinfos = list(infos[:]) + for i in range(len(dones)): + if dones[i]: + info = infos[i].copy() + ret = self.eprets[i] + eplen = self.eplens[i] + epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)} + for k in self.info_keywords: + epinfo[k] = info[k] + info['episode'] = epinfo + if self.keep_buf: + self.epret_buf.append(ret) + self.eplen_buf.append(eplen) + self.epcount += 1 + self.eprets[i] = 0 + self.eplens[i] = 0 + if self.results_writer: + self.results_writer.write_row(epinfo) + newinfos[i] = info + return obs, 
rews, terms, truncs, newinfos + + +class RunningMeanStd(): + # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + def __init__(self, epsilon=1e-4, shape=()): + self.mean = np.zeros(shape, 'float64') + self.var = np.ones(shape, 'float64') + self.count = epsilon + + def update(self, x): + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + batch_count = x.shape[0] + self.update_from_moments(batch_mean, batch_var, batch_count) + + def update_from_moments(self, batch_mean, batch_var, batch_count): + self.mean, self.var, self.count = update_mean_var_count_from_moments( + self.mean, self.var, self.count, batch_mean, batch_var, batch_count) + + +def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): + delta = batch_mean - mean + tot_count = count + batch_count + + new_mean = mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + return new_mean, new_var, new_count diff --git a/examples/utils/vtrace.py b/examples/utils/vtrace.py new file mode 100644 index 00000000..7b02c574 --- /dev/null +++ b/examples/utils/vtrace.py @@ -0,0 +1,138 @@ +# This file taken from +# https://github.com/deepmind/scalable_agent/blob/ +# cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py +# and modified. + +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions to compute V-trace off-policy actor critic targets. + +For details and theory see: + +"IMPALA: Scalable Distributed Deep-RL with +Importance Weighted Actor-Learner Architectures" +by Espeholt, Soyer, Munos et al. + +See https://arxiv.org/abs/1802.01561 for the full paper. +""" + +import collections + +import torch +import torch.nn.functional as F + +VTraceFromLogitsReturns = collections.namedtuple( + "VTraceFromLogitsReturns", + [ + "vs", + "pg_advantages", + "log_rhos", + "behavior_action_log_probs", + "target_action_log_probs", + ], +) + +VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") + + +def action_log_probs(policy_logits, actions): + return -F.nll_loss( + F.log_softmax(torch.flatten(policy_logits, 0, -2), dim=-1), + torch.flatten(actions), + reduction="none", + ).view_as(actions) + + +def from_logits( + behavior_policy_logits, + target_policy_logits, + actions, + discounts, + rewards, + values, + bootstrap_value, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0, +): + """V-trace for softmax policies.""" + + target_action_log_probs = action_log_probs(target_policy_logits, actions) + behavior_action_log_probs = action_log_probs(behavior_policy_logits, actions) + log_rhos = target_action_log_probs - behavior_action_log_probs + vtrace_returns = from_importance_weights( + log_rhos=log_rhos, + discounts=discounts, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold, + ) + return VTraceFromLogitsReturns( + log_rhos=log_rhos, + behavior_action_log_probs=behavior_action_log_probs, + target_action_log_probs=target_action_log_probs, + **vtrace_returns._asdict(), + ) + + +@torch.no_grad() +def from_importance_weights( + log_rhos, + discounts, + rewards, + values, + bootstrap_value, + clip_rho_threshold=1.0, + 
clip_pg_rho_threshold=1.0, +): + """V-trace from log importance weights.""" + with torch.no_grad(): + rhos = torch.exp(log_rhos) + if clip_rho_threshold is not None: + clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) + else: + clipped_rhos = rhos + + cs = torch.clamp(rhos, max=1.0) + # Append bootstrapped value to get [v1, ..., v_t+1] + values_t_plus_1 = torch.cat( + [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0 + ) + deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) + + acc = torch.zeros_like(bootstrap_value) + result = [] + for t in range(discounts.shape[0] - 1, -1, -1): + acc = deltas[t] + discounts[t] * cs[t] * acc + result.append(acc) + result.reverse() + vs_minus_v_xs = torch.stack(result) + + # Add V(x_s) to get v_s. + vs = torch.add(vs_minus_v_xs, values) + + # Advantage for policy gradient. + broadcasted_bootstrap_values = torch.ones_like(vs[0]) * bootstrap_value + vs_t_plus_1 = torch.cat( + [vs[1:], broadcasted_bootstrap_values.unsqueeze(0)], dim=0 + ) + if clip_pg_rho_threshold is not None: + clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) + else: + clipped_pg_rhos = rhos + pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values) + + # Make sure no gradients backpropagated through the returned values. 
+ return VTraceReturns(vs=vs, pg_advantages=pg_advantages) diff --git a/task_space/__init__.py b/task_space/__init__.py new file mode 100644 index 00000000..1561be1c --- /dev/null +++ b/task_space/__init__.py @@ -0,0 +1 @@ +from .task_space import TaskSpace diff --git a/task_space/task_space.py b/task_space/task_space.py new file mode 100644 index 00000000..1ef674be --- /dev/null +++ b/task_space/task_space.py @@ -0,0 +1,234 @@ +import itertools +from typing import Any, List, Union + +import numpy as np +from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple + + +class TaskSpace(): + def __init__(self, gym_space: Union[Space, int], tasks=None): + + if not isinstance(gym_space, Space): + gym_space = self._create_gym_space(gym_space) + + self.gym_space = gym_space + + # Autogenerate task names + if tasks is None: + tasks = self._generate_task_names(gym_space) + + self._tasks = set(tasks) if tasks is not None else None + self._encoder, self._decoder = self._make_task_encoder(gym_space, tasks) + + def _create_gym_space(self, gym_space): + if isinstance(gym_space, int): + # Syntactic sugar for discrete space + gym_space = Discrete(gym_space) + elif isinstance(gym_space, tuple): + # Syntactic sugar for discrete space + gym_space = MultiDiscrete(gym_space) + elif isinstance(gym_space, list): + # Syntactic sugar for tuple space + spaces = [] + for i, value in enumerate(gym_space): + spaces[i] = self._create_gym_space(value) + gym_space = Tuple(spaces) + elif isinstance(gym_space, dict): + # Syntactic sugar for dict space + spaces = {} + for key, value in gym_space.items(): + spaces[key] = self._create_gym_space(value) + gym_space = Dict(spaces) + return gym_space + + def _generate_task_names(self, gym_space): + if isinstance(gym_space, Discrete): + tasks = tuple(range(gym_space.n)) + elif isinstance(gym_space, MultiDiscrete): + tasks = [tuple(range(dim)) for dim in gym_space.nvec] + elif isinstance(gym_space, Tuple): + tasks = 
[self._generate_task_names(value) for value in gym_space.spaces] + elif isinstance(gym_space, Dict): + tasks = {key: tuple(self._generate_task_names(value)) for key, value in gym_space.spaces.items()} + else: + tasks = None + return tasks + + def _make_task_encoder(self, space, tasks): + if isinstance(space, Discrete): + assert space.n == len(tasks), f"Number of tasks ({space.n}) must match number of discrete options ({len(tasks)})" + self._encode_map = {task: i for i, task in enumerate(tasks)} + self._decode_map = {i: task for i, task in enumerate(tasks)} + encoder = lambda task: self._encode_map[task] if task in self._encode_map else None + decoder = lambda task: self._decode_map[task] if task in self._decode_map else None + + elif isinstance(space, Box): + encoder = lambda task: task if space.contains(np.asarray(task, dtype=space.dtype)) else None + decoder = lambda task: task if space.contains(np.asarray(task, dtype=space.dtype)) else None + elif isinstance(space, Tuple): + + assert len(space.spaces) == len(tasks), f"Number of task ({len(space.spaces)})must match options in Tuple ({len(tasks)})" + results = [list(self._make_task_encoder(s, t)) for (s, t) in zip(space.spaces, tasks)] + encoders = [r[0] for r in results] + decoders = [r[1] for r in results] + encoder = lambda task: [e(t) for e, t in zip(encoders, task)] + decoder = lambda task: [d(t) for d, t in zip(decoders, task)] + + elif isinstance(space, MultiDiscrete): + assert len(space.nvec) == len(tasks), f"Number of steps in a tasks ({len(space.nvec)}) must match number of discrete options ({len(tasks)})" + + combinations = [p for p in itertools.product(*tasks)] + encode_map = {task: i for i, task in enumerate(combinations)} + decode_map = {i: task for i, task in enumerate(combinations)} + + encoder = lambda task: encode_map[task] if task in encode_map else None + decoder = lambda task: decode_map[task] if task in decode_map else None + + elif isinstance(space, Dict): + + def helper(task, spaces, tasks, 
action="encode"): + # Iteratively encodes or decodes each space in the dictionary + output = {} + if (isinstance(spaces, dict) or isinstance(spaces, Dict)): + for key, value in spaces.items(): + if (isinstance(value, dict) or isinstance(value, Dict)): + temp = helper(task[key], value, tasks[key], action) + output.update({key: temp}) + else: + encoder, decoder = self._make_task_encoder(value, tasks[key]) + output[key] = encoder(task[key]) if action == "encode" else decoder(task[key]) + return output + + encoder = lambda task: helper(task, space.spaces, tasks, "encode") + decoder = lambda task: helper(task, space.spaces, tasks, "decode") + else: + encoder = lambda task: task + decoder = lambda task: task + return encoder, decoder + + def decode(self, encoding): + """Convert the efficient task encoding to a task that can be used by the environment.""" + return self._decoder(encoding) + + def encode(self, task): + """Convert the task to an efficient encoding to speed up multiprocessing.""" + return self._encoder(task) + + def add_task(self, task): + """Add a task to the task space. 
Only implemented for discrete spaces.""" + if task not in self._tasks: + self._tasks.add(task) + # TODO: Increment task space size + self.gym_space = self.increase_space() + # TODO: Optimize adding tasks + self._encoder, self._decoder = self._make_task_encoder(self.gym_space, self._tasks) + + def _sum_axes(list_or_size: Union[list, int]): + if isinstance(list_or_size, int) or isinstance(list_or_size, np.int64): + return list_or_size + elif isinstance(list_or_size, list) or isinstance(list_or_size, np.ndarray): + return np.prod([TaskSpace._sum_axes(x) for x in list_or_size]) + else: + raise NotImplementedError(f"{type(list_or_size)}") + + def _enumerate_axes(self, list_or_size: Union[np.ndarray, int]): + if isinstance(list_or_size, int) or isinstance(list_or_size, np.int64): + return tuple(range(list_or_size)) + elif isinstance(list_or_size, list) or isinstance(list_or_size, np.ndarray): + return tuple(itertools.product(*[self._enumerate_axes(x) for x in list_or_size])) + else: + raise NotImplementedError(f"{type(list_or_size)}") + + def seed(self, seed): + self.gym_space.seed(seed) + + @property + def tasks(self) -> List[Any]: + # TODO: Can I just use _tasks? + return self._tasks + + def get_tasks(self, gym_space: Space = None, sample_interval: float = None) -> List[tuple]: + """ + Return the full list of discrete tasks in the task_space. + Return a sample of the tasks for continuous spaces if sample_interval is specified. + Can be overridden to exclude invalid tasks within the space. 
+ """ + if gym_space is None: + gym_space = self.gym_space + + if isinstance(gym_space, Discrete): + return list(range(gym_space.n)) + elif isinstance(gym_space, Box): + raise NotImplementedError + elif isinstance(gym_space, Tuple): + return list(itertools.product([self.get_tasks(task_space=s) for s in gym_space.spaces])) + elif isinstance(gym_space, Dict): + return itertools.product([self.get_tasks(task_space=s) for s in gym_space.spaces.values()]) + elif isinstance(gym_space, MultiBinary): + return list(self._enumerate_axes(gym_space.nvec)) + elif isinstance(gym_space, MultiDiscrete): + return list(self._enumerate_axes(gym_space.nvec)) + elif gym_space is None: + return [] + else: + raise NotImplementedError + + @property + def num_tasks(self) -> int: + # TODO: Cache results + return self.count_tasks() + + def count_tasks(self, gym_space: Space = None) -> int: + """ + Return the number of discrete tasks in the task_space. + Returns None for continuous spaces. + Graph space not implemented. 
+ """ + # TODO: Test these implementations + if gym_space is None: + gym_space = self.gym_space + + if isinstance(gym_space, Discrete): + return gym_space.n + elif isinstance(gym_space, Box): + return None + elif isinstance(gym_space, Tuple): + return sum([self.count_tasks(gym_space=s) for s in gym_space.spaces]) + elif isinstance(gym_space, Dict): + return sum([self.count_tasks(gym_space=s) for s in gym_space.spaces.values()]) + elif isinstance(gym_space, MultiBinary): + return TaskSpace._sum_axes(gym_space.nvec) + elif isinstance(gym_space, MultiDiscrete): + return TaskSpace._sum_axes(gym_space.nvec) + elif gym_space is None: + return 0 + else: + raise NotImplementedError(f"Unsupported task space type: {type(gym_space)}") + + def task_name(self, task): + return repr(self.decode(task)) + + def contains(self, task): + return task in self._tasks or self.decode(task) in self._tasks + + def increase_space(self, amount: Union[int, float] = 1): + if isinstance(self.gym_space, Discrete): + assert isinstance(amount, int), f"Discrete task space can only be increased by integer amount. Got {amount} instead." 
+ return Discrete(self.gym_space.n + amount) + + def sample(self): + assert isinstance(self.gym_space, Discrete) or isinstance(self.gym_space, Box) or isinstance(self.gym_space, Dict) or isinstance(self.gym_space, Tuple) + return self.decode(self.gym_space.sample()) + + def list_tasks(self): + return list(self._tasks) + + def box_contains(self, x) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + if not isinstance(x, np.ndarray): + try: + x = np.asarray(x, dtype=self.gym_space.dtype) + except (ValueError, TypeError): + return False + + return not bool(x.shape == self.gym_space.shape and np.any((x < self.gym_space.low) | (x > self.gym_space.high))) diff --git a/task_space/test_task_space.py b/task_space/test_task_space.py new file mode 100644 index 00000000..109d0a7e --- /dev/null +++ b/task_space/test_task_space.py @@ -0,0 +1,182 @@ +import gymnasium as gym +from syllabus.task_space import TaskSpace + +if __name__ == "__main__": + # Discrete Tests + task_space = TaskSpace(gym.spaces.Discrete(3), ["a", "b", "c"]) + + assert task_space.encode("a") == 0, f"Expected 0, got {task_space.encode('a')}" + assert task_space.encode("b") == 1, f"Expected 1, got {task_space.encode('b')}" + assert task_space.encode("c") == 2, f"Expected 2, got {task_space.encode('c')}" + assert task_space.encode("d") is None, f"Expected None, got {task_space.encode('d')}" + + assert task_space.decode(0) == "a", f"Expected a, got {task_space.decode(0)}" + assert task_space.decode(1) == "b", f"Expected b, got {task_space.decode(1)}" + assert task_space.decode(2) == "c", f"Expected c, got {task_space.decode(2)}" + assert task_space.decode(3) is None, f"Expected None, got {task_space.decode(3)}" + print("Discrete tests passed!") + + # MultiDiscrete Tests + task_space = TaskSpace(gym.spaces.MultiDiscrete([3, 2]), [("a", "b", "c"), (1, 0)]) + + assert task_space.encode(('a', 1)) == 0, f"Expected 0, got {task_space.encode(('a', 1))}" + assert task_space.encode(('b', 0)) 
== 3, f"Expected 3, got {task_space.encode(('b', 0))}" + assert task_space.encode(('c', 1)) == 4, f"Expected 4, got {task_space.encode(('c', 1))}" + + assert task_space.decode(3) == ('b', 0), f"Expected ('b', 0), got {task_space.decode(3)}" + assert task_space.decode(5) == ('c', 0), f"Expected ('c', 0), got {task_space.decode(5)}" + print("MultiDiscrete tests passed!") + + # Box Tests + task_space = TaskSpace(gym.spaces.Box(low=0, high=1, shape=(2,)), [(0, 0), (0, 1), (1, 0), (1, 1)]) + + assert task_space.encode([0.0, 0.0]) == [0.0, 0.0], f"Expected [0.0, 0.0], got {task_space.encode([0.0, 0.0])}" + assert task_space.encode([0.0, 0.1]) == [0.0, 0.1], f"Expected [0.0, 0.1], got {task_space.encode([0.0, 0.1])}" + assert task_space.encode([0.1, 0.1]) == [0.1, 0.1], f"Expected [0.1, 0.1], got {task_space.encode([0.1, 0.1])}" + assert task_space.encode([1.0, 0.1]) == [1.0, 0.1], f"Expected [1.0, 0.1], got {task_space.encode([1.0, 0.1])}" + assert task_space.encode([1.0, 1.0]) == [1.0, 1.0], f"Expected [1.0, 1.0], got {task_space.encode([1.0, 1.0])}" + assert task_space.encode([1.2, 1.0]) is None, f"Expected None, got {task_space.encode([1.2, 1.0])}" + assert task_space.encode([1.0, 1.2]) is None, f"Expected None, got {task_space.encode([1.2, 1.0])}" + assert task_space.encode([-0.1, 1.0]) is None, f"Expected None, got {task_space.encode([1.2, 1.0])}" + + assert task_space.decode([1.0, 1.0]) == [1.0, 1.0], f"Expected [1.0, 1.0], got {task_space.decode([1.0, 1.0])}" + assert task_space.decode([0.1, 0.1]) == [0.1, 0.1], f"Expected [0.1, 0.1], got {task_space.decode([0.1, 0.1])}" + assert task_space.decode([-0.1, 1.0]) is None, f"Expected None, got {task_space.decode([1.2, 1.0])}" + print("Box tests passed!") + + # Tuple Tests + task_spaces = (gym.spaces.MultiDiscrete([3, 2]), gym.spaces.Discrete(3)) + task_names = ((("a", "b", "c"), (1, 0)), ("X", "Y", "Z")) + task_space = TaskSpace(gym.spaces.Tuple(task_spaces), task_names) + + assert task_space.encode((('a', 0), 'Y')) 
== [1, 1], f"Expected 0, got {task_space.encode((('a', 1),'Y'))}" + assert task_space.decode([0, 1]) == [('a', 1), 'Y'], f"Expected 0, got {task_space.decode([0, 1])}" + print("Tuple tests passed!") + + # Dictionary Tests + task_spaces = gym.spaces.Dict({ + "ext_controller": gym.spaces.MultiDiscrete([5, 2, 2]), + "inner_state": gym.spaces.Dict( + { + "charge": gym.spaces.Discrete(10), + "system_checks": gym.spaces.Tuple((gym.spaces.MultiDiscrete([3, 2]), gym.spaces.Discrete(3))), + "job_status": gym.spaces.Dict( + { + "task": gym.spaces.Discrete(5), + "progress": gym.spaces.Box(low=0, high=1, shape=(2,)), + } + ), + } + ), + }) + task_names = { + "ext_controller": [("a", "b", "c", "d", "e"), (1, 0), ("X", "Y")], + "inner_state": { + "charge": [0, 1, 13, 3, 94, 35, 6, 37, 8, 9], + "system_checks": ((("a", "b", "c"), (1, 0)), ("X", "Y", "Z")), + "job_status": { + "task": ["A", "B", "C", "D", "E"], + "progress": [(0, 0), (0, 1), (1, 0), (1, 1)], + } + } + } + task_space = TaskSpace(task_spaces, task_names) + + test_val = { + "ext_controller": ('b', 1, 'X'), + 'inner_state': { + 'charge': 1, + 'system_checks': [('a', 0), 'Y'], + 'job_status': {'task': 'C', 'progress': [0.0, 0.0]} + } + } + decode_val = { + "ext_controller": 4, + "inner_state": { + "charge": 1, + "system_checks": [1, 1], + "job_status": {"progress": [0.0, 0.0], "task": 2}, + }, + } + + assert task_space.encode(test_val) == decode_val, f"Expected {decode_val}, \n but got {task_space.encode(test_val)}" + assert task_space.decode(decode_val) == test_val, f"Expected {test_val}, \n but got {task_space.decode(decode_val)}" + + test_val_2 = { + "ext_controller": ("e", 1, "Y"), + "inner_state": { + "charge": 37, + "system_checks": [("b", 0), "Z"], + "job_status": {"progress": [0.0, 0.1], "task": "D"}, + }, + } + decode_val_2 = { + "ext_controller": 17, + "inner_state": { + "charge": 7, + "system_checks": [3, 2], + "job_status": {"progress": [0.0, 0.1], "task": 3}, + }, + } + + assert 
task_space.encode(test_val_2) == decode_val_2, f"Expected {decode_val_2}, \n but got {task_space.encode(test_val_2)}" + assert task_space.decode(decode_val_2) == test_val_2, f"Expected {test_val_2}, \n but got {task_space.decode(decode_val_2)}" + + test_val_3 = { + "ext_controller": ("e", 1, "X"), + "inner_state": { + "charge": 8, + "system_checks": [("c", 0), "X"], + "job_status": {"progress": [0.5, 0.1], "task": "E"}, + }, + } + decode_val_3 = { + "ext_controller": 16, + "inner_state": { + "charge": 8, + "system_checks": [5, 0], + "job_status": {"progress": [0.5, 0.1], "task": 4}, + }, + } + + assert task_space.encode(test_val_3) == decode_val_3, f"Expected {decode_val_3}, \n but got {task_space.encode(test_val_3)}" + assert task_space.decode(decode_val_3) == test_val_3, f"Expected {test_val_3}, \n but got {task_space.decode(decode_val_3)}" + + print("Dictionary tests passed!") + + # Test syntactic sugar + task_space = TaskSpace(3) + assert task_space.encode(0) == 0, f"Expected 0, got {task_space.encode(0)}" + assert task_space.encode(1) == 1, f"Expected 1, got {task_space.encode(1)}" + assert task_space.encode(2) == 2, f"Expected 2, got {task_space.encode(2)}" + assert task_space.encode(3) is None, f"Expected None, got {task_space.encode(3)}" + + task_space = TaskSpace((2, 4)) + assert task_space.encode((0, 0)) == 0, f"Expected 0, got {task_space.encode((0, 0))}" + assert task_space.encode((0, 1)) == 1, f"Expected 1, got {task_space.encode((0, 1))}" + assert task_space.encode((1, 0)) == 4, f"Expected 2, got {task_space.encode((1, 0))}" + assert task_space.encode((3, 3)) is None, f"Expected None, got {task_space.encode((3, 3))}" + + task_space = TaskSpace((2, 4)) + assert task_space.encode((0, 0)) == 0, f"Expected 0, got {task_space.encode((0, 0))}" + assert task_space.encode((0, 1)) == 1, f"Expected 1, got {task_space.encode((0, 1))}" + assert task_space.encode((1, 0)) == 4, f"Expected 2, got {task_space.encode((1, 0))}" + assert task_space.encode((3, 3)) is 
None, f"Expected None, got {task_space.encode((3, 3))}" + + task_space = TaskSpace({"map": 5, "level": (4, 10), "difficulty": 3}) + + encoding = task_space.encode({"map": 0, "level": (0, 0), "difficulty": 0}) + expected = {"map": 0, "level": 0, "difficulty": 0} + + encoding = task_space.encode({"map": 4, "level": (3, 9), "difficulty": 2}) + expected = {"map": 4, "level": 39, "difficulty": 2} + assert encoding == expected, f"Expected {expected}, got {encoding}" + + encoding = task_space.encode({"map": 2, "level": (2, 0), "difficulty": 1}) + expected = {"map": 2, "level": 20, "difficulty": 1} + assert encoding == expected, f"Expected {expected}, got {encoding}" + + encoding = task_space.encode({"map": 5, "level": (2, 11), "difficulty": -1}) + expected = {"map": None, "level": None, "difficulty": None} + assert encoding == expected, f"Expected {expected}, got {encoding}" + print("All tests passed!") diff --git a/tests/determinism.py b/tests/determinism.py new file mode 100644 index 00000000..af4e8e86 --- /dev/null +++ b/tests/determinism.py @@ -0,0 +1,154 @@ +import numpy as np +from syllabus.tests import evaluate_random_policy + +N_EPISODES = 10 + + +def print_if_verbose(verbose, *args, **kwargs): + if verbose > 0: + print(*args, **kwargs) + + +def compare_obs(obs1, obs2): + if not isinstance(obs1, type(obs2)): + print(f"Type mismatch: {type(obs1)} != {type(obs2)}") + return False + + if isinstance(obs1, dict): + for key in obs1.keys(): + if key not in obs2: + print(f"Key {key} not in obs2") + return False + if not compare_obs(obs1[key], obs2[key]): + return False + return True + if isinstance(obs1, (list, tuple, np.ndarray)): + if len(obs1) != len(obs2): + print(f"Length mismatch: {len(obs1)} != {len(obs2)}") + return False + for i in range(len(obs1)): + if not compare_obs(obs1[i], obs2[i]): + return False + return True + else: + return obs1 == obs2 + + +def compare_episodes(make_env, task1, task2, verbose=0): + term1 = trunc1 = term2 = trunc2 = False + step = 0 + 
num_obs_failed = 0 + num_rews_failed = 0 + + # Set tasks + env1 = make_env() + env1.reset(new_task=task1) + env2 = make_env() + env2.reset(new_task=task2) + + # Seed spaces + env1.action_space.seed(0) + env1.observation_space.seed(0) + env1.task_space.seed(0) + env2.action_space.seed(0) + env2.observation_space.seed(0) + env2.task_space.seed(0) + + while not (term1 or trunc1 or term2 or trunc2): + action1 = env1.action_space.sample() + action2 = env2.action_space.sample() + + # Check actions + if action1 != action2: + print_if_verbose(verbose, f"Step {step}: Actions are not the same: {action1} != {action2}. Stopping test.") + return False + + obs1, rew1, term1, trunc1, info1 = env1.step(action1) + obs2, rew2, term2, trunc2, info2 = env2.step(action2) + + # Check observations + if not compare_obs(obs1, obs2): + if num_obs_failed == 0: + print_if_verbose(verbose, f"Step {step}: Obs are not the same: {obs1} != {obs2}. This message will not print for future steps.") + num_obs_failed += 1 + + # Check rewards + if rew1 != rew2: + if num_rews_failed == 0: + print_if_verbose(verbose, f"Step {step}: Rewards are not the same: {rew1} != {rew2}. This message will not print for future steps.") + num_rews_failed += 1 + + # Check terms + if term1 != term2: + print_if_verbose(verbose, f"Step {step}: Terms are not the same: {term1} != {term2}. Stopping test.") + return False + + # Check truncs + if trunc1 != trunc2: + print_if_verbose(verbose, f"Step {step}: Truncs are not the same: {trunc1} != {trunc2}. Stopping test.") + return False + + step += 1 + return num_obs_failed == 0 and num_rews_failed == 0 + + +def test_determinism(make_env, verbose=0): + # TODO: Use the task space to sampele seeds/tasks + test_env = make_env() + assert hasattr(test_env, "task_space"), "Environment does not have a task space. make_env must return a TaskEnv or use a TaskWrapper." 
+ task_space = test_env.task_space + + print_if_verbose(verbose, "Runnning determinism tests...") + + # Test full episode returns + print_if_verbose(verbose, "\nTesting average episodic returns...") + seeds = [task_space.sample() for _ in range(N_EPISODES)] + return1, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=seeds) + return2, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=seeds) + full_return_test = return1 == return2 + if full_return_test: + print_if_verbose(verbose, "PASSED: Random policy returns are deterministic!") + else: + print_if_verbose(verbose, f"FAILED: Random policy returns are not deterministic! {return1} != {return2}") + + # Test individual episode returns + print_if_verbose(verbose, "\nTesting individual episode rewards...") + avg_returns, returns = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=[task_space.sample()] * N_EPISODES) + return_test = all(returns == avg_returns) + if return_test: + print_if_verbose(verbose, "PASSED: Episodes returns are deterministic!") + else: + print_if_verbose(verbose, f"FAILED: Episodes returns are not deterministic! {avg_returns} != {returns}") + + print_if_verbose(verbose, "\nTesting different seeds...") + task1 = task2 = task_space.sample() + while task1 == task2: + task2 = task_space.sample() + return1, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=[task1] * N_EPISODES) + return2, _ = evaluate_random_policy(make_env, num_episodes=N_EPISODES, seeds=[task2] * N_EPISODES) + + test3 = return1 != return2 + if test3: + print_if_verbose(verbose, "PASSED: Random policy returns with different seeds are different.") + else: + print_if_verbose(verbose, f"FAILED: Random policy returns with different seeds are the same. 
{return1} == {return2}") + + print_if_verbose(verbose, "\nTesting actions, rewards, and observations seeds...") + task1 = task2 = task_space.sample() + step_tests_same = compare_episodes(make_env, task1, task2, verbose=verbose) + while task1 == task2: + task2 = task_space.sample() + step_tests_different = compare_episodes(make_env, task1, 2, verbose=0) + + if step_tests_same and not step_tests_different: + print_if_verbose(verbose, "PASSED: Environment returns on individual steps are deterministic with respect to seed.") + elif step_tests_different: + print_if_verbose(verbose, "FAILED: Environment returns on individual steps are deterministic even with different seeds.") + else: + print_if_verbose(verbose, "FAILED: Environment returns on individual steps are not deterministic with the same seed.") + + return { + "avg_episodic_returns": full_return_test, + "episodic_returns": return_test, + "step_values": step_tests_same + } diff --git a/tests/sync_test_curriculum.py b/tests/sync_test_curriculum.py new file mode 100644 index 00000000..45399f29 --- /dev/null +++ b/tests/sync_test_curriculum.py @@ -0,0 +1,51 @@ +import typing +from typing import Any, List, Union + +from syllabus.curricula import SequentialCurriculum + + +class SyncTestCurriculum(SequentialCurriculum): + """ + Base class and API for defining curricula to interface with Gym environments. 
+ """ + REQUIRES_STEP_UPDATES = False + REQUIRES_EPISODE_UPDATES = True + REQUIRES_CENTRAL_UPDATES = False + + def __init__(self, num_envs, num_episodes, *curriculum_args, **curriculum_kwargs): + # Create a manual curriculum with a new task per episode, repeated across all envs + task_list = [f"task {i+1}" for i in range(num_episodes)] + stopping = [f"tasks>={num_envs}"] * (num_episodes - 1) + super().__init__(task_list, stopping, *curriculum_args, **curriculum_kwargs) + self.num_envs = num_envs + self.num_episodes = num_episodes + self.task_counts = {self.task_space.encode(task): 0 for task in task_list} + self.task_counts[0] = 0 # Error task + self.total_reward = 0 + self.total_dones = 0 + + def update_on_episode(self, episode_return, episode_len, episode_task, env_id: int = None) -> None: + super().update_on_episode(episode_return, episode_len, episode_task, env_id) + self.total_reward += episode_return + self.total_dones += 1 + self.task_counts[episode_task] += 1 + + def get_stats(self): + return { + "total_reward": self.total_reward, + "total_dones": self.total_dones, + "task_counts": self.task_counts + } + + def sample(self, k: int = 1) -> Union[List, Any]: + remaining_tasks = (self.num_episodes * self.num_envs) - self.total_tasks + if remaining_tasks < k: + tasks = super().sample(k=remaining_tasks) + [0] * (k - remaining_tasks) + else: + tasks = [] + while k > self.num_envs - self.n_tasks: + tasks += super().sample(k=self.num_envs - self.n_tasks) + k -= (self.num_envs - self.n_tasks) + if k > 0: + tasks += super().sample(k=k) + return tasks diff --git a/tests/sync_test_env.py b/tests/sync_test_env.py new file mode 100644 index 00000000..1c054228 --- /dev/null +++ b/tests/sync_test_env.py @@ -0,0 +1,33 @@ +import warnings +import gymnasium as gym +from syllabus.core import TaskEnv +from syllabus.task_space import TaskSpace + + +class SyncTestEnv(TaskEnv): + def __init__(self, num_episodes, num_steps=100): + super().__init__() + self.num_steps = num_steps + 
self.action_space = gym.spaces.Discrete(2) + self.observation_space = gym.spaces.Tuple((gym.spaces.Discrete(self.num_steps), gym.spaces.Discrete(2))) + self.task_space = TaskSpace(gym.spaces.Discrete(num_episodes + 1), ["error task"] + [f"task {i+1}" for i in range(num_episodes)]) + self.task = "error_task" + + def reset(self, new_task=None): + if new_task == "error task": + warnings.warn("Received error task. This likely means that too many tasks are being requested.") + if new_task is None: + warnings.warn("No task provided. Resetting to error task.") + self.task = new_task + self._turn = 0 + return (self._turn, None), {"content": "reset", "task": self.task} + + def step(self, action): + self._turn += 1 + + obs = self.observation((self._turn, action)) + rew = 1 + term = self._turn >= self.num_steps + trunc = False + info = {"content": "step", "task_completion": self._task_completion(obs, rew, term, trunc, {})} + return obs, rew, term, trunc, info diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..98bac823 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,278 @@ +import time +import warnings +from multiprocessing import Process + +import gym as openai_gym +import gymnasium as gym +import numpy as np +import ray +import torch +from shimmy.openai_gym_compatibility import GymV21CompatibilityV0 + +from syllabus.core import MultiProcessingSyncWrapper, RaySyncWrapper, ReinitTaskWrapper +from syllabus.examples.task_wrappers.cartpole_task_wrapper import CartPoleTaskWrapper +from syllabus.task_space import TaskSpace +from syllabus.tests import SyncTestEnv + + +def evaluate_random_policy(make_env, num_episodes=100, seeds=None): + env = make_env(seed=seeds[0] if seeds else None) + + # Seed environment + env.action_space.seed(0) + env.observation_space.seed(0) + + episode_returns = [] + + for i in range(num_episodes): + episode_return = 0 + if seeds: + _ = env.reset(new_task=seeds[i]) + env.action_space.seed(0) + env.observation_space.seed(0) + 
else: + _ = env.reset() + term = trunc = False + while not (term or trunc): + action = env.action_space.sample() + _, rew, term, trunc, _ = env.step(action) + episode_return += rew + episode_returns.append(episode_return) + + avg_return = sum(episode_returns) / len(episode_returns) + # print(f"Average Episodic Return: {avg_return}") + return avg_return, episode_returns + + +def run_episode(env, new_task=None, curriculum=None, env_id=0): + """Run a single episode of the environment.""" + if new_task is not None: + obs = env.reset(new_task=new_task) + else: + obs = env.reset() + term = trunc = False + ep_rew = 0 + ep_len = 0 + while not (term or trunc): + action = env.action_space.sample() + obs, rew, term, trunc, info = env.step(action) + if curriculum and curriculum.requires_step_updates: + curriculum.update_on_step(env.task_space.encode(env.task), obs, rew, term, trunc, info, env_id=env_id) + curriculum.update_task_progress(env.task_space.encode(env.task), info["task_completion"], env_id=env_id) + ep_rew += rew + ep_len += 1 + if curriculum and curriculum.requires_episode_updates: + curriculum.update_on_episode(ep_rew, ep_len, env.task_space.encode(env.task), env_id=env_id) + return ep_rew + + +def run_set_length(env, curriculum=None, episodes=None, steps=None, env_id=0, env_outputs=None): + """Run environment for a set number of episodes or steps.""" + assert episodes is not None or steps is not None, "Must specify either episodes or steps." + assert episodes is None or steps is None, "Cannot specify both episodes and steps." + total_episodes = episodes if episodes is not None else 2**16 - 1 + total_steps = steps if steps is not None else 2**16 - 1 + n_steps = 0 + n_episodes = 0 + + # Resume stepping from the last observation. 
+ if env_outputs is None: + obs = env.reset(new_task=curriculum.sample()[0] if curriculum else None) + + while n_episodes < total_episodes and n_steps < total_steps: + term = trunc = False + ep_rew = 0 + ep_len = 0 + while not (term or trunc) and n_steps < total_steps: + action = env.action_space.sample() + obs, rew, term, trunc, info = env.step(action) + if curriculum and curriculum.requires_step_updates: + curriculum.update_on_step(env.task_space.encode(env.task), obs, rew, term, trunc, info, env_id=env_id) + curriculum.update_task_progress(env.task_space.encode(env.task), info["task_completion"], env_id=env_id) + ep_rew += rew + ep_len += 1 + n_steps += 1 + if (term or trunc) and curriculum and curriculum.requires_episode_updates: + curriculum.update_on_episode(ep_rew, ep_len, env.task_space.encode(env.task), env_id=env_id) + n_episodes += 1 + obs = env.reset(new_task=curriculum.sample()[0] if curriculum else None) + + return (obs, rew, term, trunc, info) + + +def run_episodes(env_fn, env_args, env_kwargs, curriculum=None, num_episodes=10, env_id=0): + """Run multiple episodes of the environment.""" + env = env_fn(env_args=env_args, env_kwargs=env_kwargs) + ep_rews = [] + for _ in range(num_episodes): + if curriculum: + task = env.task_space.decode(curriculum.sample()[0]) + ep_rews.append(run_episode(env, new_task=task, curriculum=curriculum, env_id=env_id)) + else: + ep_rews.append(run_episode(env)) + env.close() + + +def run_episodes_queue(env_fn, env_args, env_kwargs, curriculum_components, sync=True, num_episodes=10, update_on_step=True, buffer_size=2, env_id=0): + env = env_fn(curriculum_components, env_args=env_args, env_kwargs=env_kwargs, type="queue", update_on_step=update_on_step, buffer_size=buffer_size) if sync else env_fn(env_args=env_args, env_kwargs=env_kwargs) + ep_rews = [] + for _ in range(num_episodes): + ep_rews.append(run_episode(env, env_id=env_id)) + env.close() + + +@ray.remote +def run_episodes_ray(env_fn, env_args, env_kwargs, sync=True, 
num_episodes=10, update_on_step=True): + env = env_fn(env_args=env_args, env_kwargs=env_kwargs, type="ray", update_on_step=update_on_step) if sync else env_fn(env_args=env_args, env_kwargs=env_kwargs) + ep_rews = [] + for _ in range(num_episodes): + ep_rews.append(run_episode(env)) + env.close() + + +def run_single_process(env_fn, env_args=(), env_kwargs={}, curriculum=None, num_envs=2, num_episodes=10): + start = time.time() + for _ in range(num_envs): + run_episodes(env_fn, env_args, env_kwargs, curriculum=curriculum, num_episodes=num_episodes) + end = time.time() + native_speed = end - start + return native_speed + + +def run_native_multiprocess(env_fn, env_args=(), env_kwargs={}, curriculum=None, num_envs=2, num_episodes=10, update_on_step=True, buffer_size=2): + start = time.time() + # Choose multiprocessing and curriculum methods + if curriculum: + target = run_episodes_queue + args = (env_fn, env_args, env_kwargs, curriculum.get_components(), True, num_episodes, update_on_step and curriculum.curriculum.requires_step_updates, buffer_size) + else: + target = run_episodes + args = (env_fn, env_args, env_kwargs, (), num_episodes) + + # Run episodes + actors = [] + for i in range(num_envs): + nargs = args + (i,) + actors.append(Process(target=target, args=nargs)) + for actor in actors: + actor.start() + for actor in actors: + actor.join() + end = time.time() + native_speed = end - start + + # Stop curriculum to prevent it from slowing down the next test + if curriculum: + curriculum.stop() + return native_speed + + +def run_ray_multiprocess(env_fn, env_args=(), env_kwargs={}, curriculum=None, num_envs=2, num_episodes=10, update_on_step=True): + if curriculum: + target = run_episodes_ray + args = (env_fn, env_args, env_kwargs, True, num_episodes, update_on_step) + else: + target = run_episodes_ray + args = (env_fn, env_args, env_kwargs, False, num_episodes, update_on_step) + + start = time.time() + remotes = [] + for _ in range(num_envs): + 
remotes.append(target.remote(*args)) + ray.get(remotes) + end = time.time() + ray_speed = end - start + if curriculum: + ray.kill(curriculum.curriculum) + return ray_speed + +def get_test_values(x): + return torch.unsqueeze(torch.Tensor(np.array([0] * len(x))), -1) + + +# Sync Test Environment +def create_synctest_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + env = SyncTestEnv(*env_args, **env_kwargs) + if type == "queue": + env = MultiProcessingSyncWrapper(env, *args, task_space=env.task_space, **kwargs) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Cartpole Tests +def create_cartpole_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + env = gym.make("CartPole-v1", **env_kwargs) + env = CartPoleTaskWrapper(env) + + if type == "queue": + env = MultiProcessingSyncWrapper(env, *args, task_space=env.task_space, **kwargs) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Nethack Tests +def create_nethack_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + try: + from nle.env.tasks import NetHackScore + + from syllabus.examples.task_wrappers.nethack_wrappers import \ + NethackTaskWrapper + except ImportError: + warnings.warn("Unable to import nle.") + + env = NetHackScore(*env_args, **env_kwargs) + env = NethackTaskWrapper(env) + + if type == "queue": + env = MultiProcessingSyncWrapper( + env, *args, task_space=env.task_space, **kwargs + ) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Procgen Tests +def create_procgen_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + try: + import procgen + + from syllabus.examples.task_wrappers.procgen_task_wrapper import \ + ProcgenTaskWrapper + except ImportError: + warnings.warn("Unable to import procgen.") + + env = openai_gym.make("procgen-bigfish-v0", *env_args, **env_kwargs) + env 
= GymV21CompatibilityV0(env=env) + env = ProcgenTaskWrapper(env, "bigfish") + + if type == "queue": + env = MultiProcessingSyncWrapper( + env, *args, task_space=env.task_space, **kwargs + ) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env + + +# Minigrid Tests +def create_minigrid_env(*args, type=None, env_args=(), env_kwargs={}, **kwargs): + try: + from gym_minigrid.envs import DoorKeyEnv # noqa: F401 + from gym_minigrid.register import env_list + except ImportError: + warnings.warn("Unable to import gym_minigrid.") + env = gym.make("MiniGrid-DoorKey-5x5-v0", **env_kwargs) + + def create_env(task): + return gym.make(task) + + task_space = TaskSpace(gym.spaces.Discrete(len(env_list)), env_list) + env = ReinitTaskWrapper(env, create_env, task_space=task_space) + if type == "queue": + env = MultiProcessingSyncWrapper(env, *args, task_space=env.task_space, **kwargs) + elif type == "ray": + env = RaySyncWrapper(env, *args, task_space=env.task_space, **kwargs) + return env From d17a6909cbb9dd49825e8e90b6bf0d2add2abe3a Mon Sep 17 00:00:00 2001 From: allisony9 Date: Wed, 24 Apr 2024 13:13:57 -0400 Subject: [PATCH 13/14] access task_id info --- syllabus/core/environment_sync_wrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/syllabus/core/environment_sync_wrapper.py b/syllabus/core/environment_sync_wrapper.py index 6edee7cc..106a8793 100644 --- a/syllabus/core/environment_sync_wrapper.py +++ b/syllabus/core/environment_sync_wrapper.py @@ -121,7 +121,9 @@ def step(self, action): "request_sample": True } self.components.put_update([task_update, episode_update]) - + + info["task"] = self.task_space.encode(self.get_task()) + return obs, rew, term, trunc, info def _package_step_updates(self): From 9f68e1bc4254c333cce6255d626992b36d62b323 Mon Sep 17 00:00:00 2001 From: allisony9 Date: Wed, 24 Apr 2024 14:58:01 -0400 Subject: [PATCH 14/14] add info['task'] to reset, api for the normalize func 
--- syllabus/core/curriculum_base.py | 6 ++++++ syllabus/core/curriculum_sync_wrapper.py | 6 ++++++ syllabus/core/environment_sync_wrapper.py | 2 ++ 3 files changed, 14 insertions(+) diff --git a/syllabus/core/curriculum_base.py b/syllabus/core/curriculum_base.py index 5e75faff..3f31ab95 100644 --- a/syllabus/core/curriculum_base.py +++ b/syllabus/core/curriculum_base.py @@ -120,6 +120,12 @@ def update_on_episode(self, episode_return: float, episode_length: int, episode_ self.stat_recorder.record(episode_return, episode_length, episode_task, env_id) #raise NotImplementedError("Not yet implemented.") + def normalize(self, reward, task): + """ + Normalize reward by task. + """ + return self.stat_recorder.normalize(reward, task) + def update_on_demand(self, metrics: Dict): """Update the curriculum with arbitrary inputs. diff --git a/syllabus/core/curriculum_sync_wrapper.py b/syllabus/core/curriculum_sync_wrapper.py index c1c2fb27..dd86a858 100644 --- a/syllabus/core/curriculum_sync_wrapper.py +++ b/syllabus/core/curriculum_sync_wrapper.py @@ -67,6 +67,9 @@ def update_batch(self, metrics): def add_task(self, task): self.curriculum.add_task(task) + + def normalize(self, rewards, task): + return self.curriculum.normalize(rewards, task) class MultiProcessingComponents: @@ -253,6 +256,9 @@ def add_task(self, task): def get_components(self): return self._components + + def normalize(self, rewards, task): + return super().normalize(rewards, task) def remote_call(func): diff --git a/syllabus/core/environment_sync_wrapper.py b/syllabus/core/environment_sync_wrapper.py index 106a8793..e9a2b355 100644 --- a/syllabus/core/environment_sync_wrapper.py +++ b/syllabus/core/environment_sync_wrapper.py @@ -80,6 +80,8 @@ def reset(self, *args, **kwargs): added_tasks = message["added_tasks"] for add_task in added_tasks: self.env.add_task(add_task) + obs, info = self.env.reset(*args, new_task=next_task, **kwargs) + info["task"] = self.task_space.encode(self.get_task()) return 
self.env.reset(*args, new_task=next_task, **kwargs) def step(self, action):