diff --git a/examples/mobile_env_training_comparison.ipynb b/examples/mobile_env_training_comparison.ipynb new file mode 100644 index 0000000..a367186 --- /dev/null +++ b/examples/mobile_env_training_comparison.ipynb @@ -0,0 +1,625 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mobile-Env Training Comparison: Stable-Baselines3 vs PufferLib\n", + "\n", + "This notebook demonstrates training a mobile-env environment using two different approaches:\n", + "1. **Stable-Baselines3**: A popular RL library with implementations of reliable algorithms\n", + "2. **PufferLib**: A high-performance RL library designed for faster training\n", + "\n", + "We'll compare training speed and convergence using TensorBoard logging.\n", + "\n", + "## Setup\n", + "\n", + "First, install the required dependencies:\n", + "```bash\n", + "pip install mobile-env stable-baselines3 pufferlib tensorboard gym==0.21.0\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.\n", + "Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.\n", + "See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.\n", + "/Users/stefanshschneider/Projects/private/mobile-env/.venv/lib/python3.11/site-packages/pygame/pkgdata.py:25: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + " from pkg_resources import resource_stream, resource_exists\n", + "/Users/stefanshschneider/Projects/private/mobile-env/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "import numpy as np\n", + "import mobile_env\n", + "from stable_baselines3 import PPO\n", + "from stable_baselines3.common.callbacks import BaseCallback\n", + "from stable_baselines3.common.vec_env import DummyVecEnv\n", + "import pufferlib\n", + "import pufferlib.emulation\n", + "import gym" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Important Notes\n", + "\n", + "**About This Comparison:**\n", + "\n", + "This notebook provides a **demonstration and comparison** of two popular RL frameworks:\n", + "- **Stable-Baselines3** is used with its well-tested PPO implementation\n", + "- **PufferLib** uses a custom simplified PPO implementation to demonstrate the library's capabilities\n", + "\n", + "**Key Points:**\n", + "- Both approaches train on the same mobile-env environment\n", + "- Training speed can vary based on hardware (CPU vs GPU)\n", + "- PufferLib's strength is in vectorized environment handling\n", + "- The comparison focuses on training speed and convergence patterns\n", + "- Results may vary based on hyperparameter tuning\n", + "\n", + "**Disclaimer:**\n", + "The PufferLib implementation shown here is simplified for educational purposes. For production use, consider using PufferLib's full framework or integrating with CleanRL." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 1: Training with Stable-Baselines3\n", + "\n", + "We'll use the PPO algorithm from Stable-Baselines3 to train on a simple mobile-env environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "NameNotFound", + "evalue": "Environment `mobile-small-central` doesn't exist.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameNotFound\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 10\u001b[39m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env\n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# Test environment creation\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m env = \u001b[43mcreate_mobile_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 11\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEnvironment: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 12\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mObservation space: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00menv.observation_space\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 6\u001b[39m, in \u001b[36mcreate_mobile_env\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcreate_mobile_env\u001b[39m():\n\u001b[32m 5\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Create a simple mobile-env environment for training.\"\"\"\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m6\u001b[39m env = \u001b[43mgym\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mmobile-small-central-v0\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m env\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Projects/private/mobile-env/.venv/lib/python3.11/site-packages/gym/envs/registration.py:676\u001b[39m, in \u001b[36mmake\u001b[39m\u001b[34m(id, **kwargs)\u001b[39m\n\u001b[32m 675\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mmake\u001b[39m(\u001b[38;5;28mid\u001b[39m: \u001b[38;5;28mstr\u001b[39m, **kwargs) -> \u001b[33m\"\u001b[39m\u001b[33mEnv\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m--> \u001b[39m\u001b[32m676\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mregistry\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmake\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mid\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Projects/private/mobile-env/.venv/lib/python3.11/site-packages/gym/envs/registration.py:490\u001b[39m, in \u001b[36mEnvRegistry.make\u001b[39m\u001b[34m(self, path, **kwargs)\u001b[39m\n\u001b[32m 487\u001b[39m namespace, name, version = parse_env_id(path)\n\u001b[32m 489\u001b[39m \u001b[38;5;66;03m# Get all versions of this spec.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m490\u001b[39m versions = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43menv_specs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mversions\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnamespace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 492\u001b[39m \u001b[38;5;66;03m# We check what the latest version of the environment is and display\u001b[39;00m\n\u001b[32m 493\u001b[39m \u001b[38;5;66;03m# a warning if the user is attempting to initialize an older version\u001b[39;00m\n\u001b[32m 494\u001b[39m \u001b[38;5;66;03m# or an unversioned one.\u001b[39;00m\n\u001b[32m 495\u001b[39m latest_versioned_spec = \u001b[38;5;28mmax\u001b[39m(\n\u001b[32m 496\u001b[39m \u001b[38;5;28mfilter\u001b[39m(\u001b[38;5;28;01mlambda\u001b[39;00m spec: spec.version, versions),\n\u001b[32m 497\u001b[39m key=\u001b[38;5;28;01mlambda\u001b[39;00m spec: cast(\u001b[38;5;28mint\u001b[39m, spec.version),\n\u001b[32m 498\u001b[39m default=\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m 499\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Projects/private/mobile-env/.venv/lib/python3.11/site-packages/gym/envs/registration.py:220\u001b[39m, in \u001b[36mEnvSpecTree.versions\u001b[39m\u001b[34m(self, namespace, name)\u001b[39m\n\u001b[32m 203\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mversions\u001b[39m(\u001b[38;5;28mself\u001b[39m, namespace: Optional[\u001b[38;5;28mstr\u001b[39m], name: \u001b[38;5;28mstr\u001b[39m) -> Sequence[EnvSpec]:\n\u001b[32m 204\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 205\u001b[39m \u001b[33;03m Returns the versions associated with a namespace and name.\u001b[39;00m\n\u001b[32m 206\u001b[39m \n\u001b[32m (...)\u001b[39m\u001b[32m 218\u001b[39m \u001b[33;03m See `gym/envs/__relocated__.py` for more info.\u001b[39;00m\n\u001b[32m 219\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m220\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_assert_name_exists\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnamespace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 222\u001b[39m versions = \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mself\u001b[39m.tree[namespace][name].values())\n\u001b[32m 224\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m namespace \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m internal_env_relocation_map:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/Projects/private/mobile-env/.venv/lib/python3.11/site-packages/gym/envs/registration.py:297\u001b[39m, in \u001b[36mEnvSpecTree._assert_name_exists\u001b[39m\u001b[34m(self, namespace, name)\u001b[39m\n\u001b[32m 295\u001b[39m message += \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m Did you mean: `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msuggestions[\u001b[32m0\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m`?\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 296\u001b[39m \u001b[38;5;66;03m# Throw the error\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m297\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m error.NameNotFound(message)\n", + "\u001b[31mNameNotFound\u001b[39m: Environment `mobile-small-central` doesn't exist." + ] + } + ], + "source": [ + "import mobile_env\n", + "\n", + "# Create mobile-env environment\n", + "def create_mobile_env():\n", + " \"\"\"Create a simple mobile-env environment for training.\"\"\"\n", + " env = gym.make('mobile-small-central-v0')\n", + " return env\n", + "\n", + "# Test environment creation\n", + "env = create_mobile_env()\n", + "print(f\"Environment: {env}\")\n", + "print(f\"Observation space: {env.observation_space}\")\n", + "print(f\"Action space: {env.action_space}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Custom callback for logging\n", + "class TensorboardCallback(BaseCallback):\n", + " def __init__(self, verbose=0):\n", + " super(TensorboardCallback, self).__init__(verbose)\n", + " self.episode_rewards = []\n", + " self.episode_lengths = []\n", + " \n", + " def _on_step(self) -> bool:\n", + " # Log additional metrics if needed\n", + " return True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Training with Stable-Baselines3\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"Training with Stable-Baselines3 (PPO)\")\n", + "print(\"=\"*60)\n", + "\n", + "# Create environment\n", + "sb3_env = create_mobile_env()\n", + "\n", + "# Create PPO model\n", + "sb3_model = PPO(\n", + " \"MlpPolicy\",\n", + " sb3_env,\n", + " verbose=1,\n", + " tensorboard_log=\"./logs/sb3_mobile_env\",\n", + " learning_rate=3e-4,\n", + " n_steps=2048,\n", + " batch_size=64,\n", + " n_epochs=10,\n", + " gamma=0.99,\n", + " gae_lambda=0.95,\n", + ")\n", + "\n", + "# Train the model\n", + "start_time = time.time()\n", + "sb3_model.learn(\n", + " total_timesteps=50000,\n", + " callback=TensorboardCallback(),\n", + " progress_bar=True\n", + ")\n", + "sb3_training_time = time.time() - start_time\n", + "\n", + "print(f\"\\nStable-Baselines3 Training Time: {sb3_training_time:.2f} seconds\")\n", + "\n", + "# Save the model\n", + "sb3_model.save(\"mobile_env_sb3_model\")\n", + "print(\"Model saved to 'mobile_env_sb3_model.zip'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 2: Training with PufferLib\n", + "\n", + "Now we'll wrap the same mobile-env environment with PufferLib and train using its optimized framework." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Wrap environment with PufferLib\n", + "def create_puffer_env():\n", + " \"\"\"Create a PufferLib-wrapped mobile-env environment.\"\"\"\n", + " # Create base environment\n", + " base_env = gym.make('mobile-small-central-v0')\n", + " \n", + " # Wrap with PufferLib\n", + " puffer_env = pufferlib.emulation.PufferEnv(env=base_env)\n", + " return puffer_env\n", + "\n", + "# Test PufferLib environment\n", + "puffer_env = create_puffer_env()\n", + "print(f\"PufferLib Environment: {puffer_env}\")\n", + "print(f\"Observation space: {puffer_env.observation_space}\")\n", + "print(f\"Action space: {puffer_env.action_space}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Training with PufferLib\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"Training with PufferLib\")\n", + "print(\"=\"*60)\n", + "\n", + "# Create vectorized environment for PufferLib\n", + "def make_env():\n", + " return create_puffer_env()\n", + "\n", + "# PufferLib configuration\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "\n", + "# Create vectorized environments\n", + "vec_env = pufferlib.vector.make(\n", + " make_env,\n", + " num_envs=4,\n", + " envs_per_worker=1,\n", + " envs_per_batch=4,\n", + ")\n", + "\n", + "# Simple policy network for PufferLib\n", + "class SimplePolicy(torch.nn.Module):\n", + " def __init__(self, obs_space, action_space):\n", + " super().__init__()\n", + " obs_shape = obs_space.shape[0] if hasattr(obs_space, 'shape') else obs_space.n\n", + " action_dim = action_space.n if hasattr(action_space, 'n') else action_space.shape[0]\n", + " \n", + " self.network = torch.nn.Sequential(\n", + " torch.nn.Linear(obs_shape, 64),\n", + " torch.nn.ReLU(),\n", + " torch.nn.Linear(64, 64),\n", + " torch.nn.ReLU(),\n", + " )\n", + " self.actor = torch.nn.Linear(64, action_dim)\n", + " self.critic = torch.nn.Linear(64, 1)\n", + " \n", + " def forward(self, obs):\n", + " hidden = self.network(obs)\n", + " return self.actor(hidden), self.critic(hidden)\n", + " \n", + " def get_action_and_value(self, obs, action=None):\n", + " hidden = self.network(obs)\n", + " logits = self.actor(hidden)\n", + " probs = torch.softmax(logits, dim=-1)\n", + " \n", + " if action is None:\n", + " action = torch.multinomial(probs, 1)\n", + " \n", + " log_prob = F.log_softmax(logits, dim=-1)\n", + " action_log_prob = log_prob.gather(1, action)\n", + " entropy = -(probs * log_prob).sum(-1)\n", + " value = self.critic(hidden)\n", + " \n", + " return action, action_log_prob, entropy, value\n", + "\n", + "# Get environment specs\n", + "temp_env = create_puffer_env()\n", + "policy = SimplePolicy(temp_env.observation_space, temp_env.action_space)\n", + "optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)\n", + "\n", + "# Setup TensorBoard\n", + "writer = SummaryWriter(log_dir='./logs/pufferlib_mobile_env')\n", + "\n", + "# Training hyperparameters\n", + "total_timesteps = 50000\n", + "num_steps = 128 # Steps per rollout\n", + "num_envs = 4\n", + "batch_size = num_envs * num_steps\n", + "num_updates = total_timesteps // batch_size\n", + "gamma = 0.99 # Discount factor\n", + "gae_lambda = 0.95 # GAE parameter\n", + "clip_coef = 0.2 # PPO clipping coefficient\n", + "\n", + "print(f\"Starting PufferLib training for {total_timesteps} steps...\")\n", + "print(f\"Number of updates: {num_updates}\")\n", + "\n", + "start_time = time.time()\n", + "global_step = 0\n", + "\n", + "# Reset environment\n", + "obs = vec_env.reset()\n", + "\n", + "# Episode tracking\n", + "episode_rewards = []\n", + "episode_lengths = []\n", + "\n", + "for update in range(num_updates):\n", + " # Storage for rollout\n", + " obs_batch = []\n", + " actions_batch = []\n", + " logprobs_batch = []\n", + " rewards_batch = []\n", + " dones_batch = []\n", + " values_batch = []\n", + " \n", + " # Collect rollout\n", + " for step in range(num_steps):\n", + " global_step += num_envs\n", + " obs_tensor = torch.FloatTensor(obs)\n", + " \n", + " with torch.no_grad():\n", + " action, logprob, _, value = policy.get_action_and_value(obs_tensor)\n", + " \n", + " obs_batch.append(obs)\n", + " actions_batch.append(action.squeeze().numpy())\n", + " logprobs_batch.append(logprob.squeeze().numpy())\n", + " values_batch.append(value.squeeze().numpy())\n", + " \n", + " # Step environment\n", + " next_obs, reward, done, info = vec_env.step(action.squeeze().numpy())\n", + " \n", + " rewards_batch.append(reward)\n", + " dones_batch.append(done)\n", + " \n", + " # Track episodes\n", + " # Handle info as list or dict\n", + " if isinstance(info, list):\n", + " for i, d in enumerate(done):\n", + " if d:\n", + " episode_rewards.append(reward[i])\n", + " episode_lengths.append(1)\n", + " else:\n", + " for idx, d in enumerate(done):\n", + " if d and idx in info and 'episode' in info[idx]:\n", + " episode_info = info[idx]['episode']\n", + " episode_rewards.append(episode_info.get('r', reward[idx]))\n", + " episode_lengths.append(episode_info.get('l', 1))\n", + " \n", + " obs = next_obs\n", + " \n", + " # Convert to numpy arrays\n", + " obs_array = np.array(obs_batch)\n", + " actions_array = np.array(actions_batch)\n", + " logprobs_array = np.array(logprobs_batch)\n", + " rewards_array = np.array(rewards_batch)\n", + " dones_array = np.array(dones_batch)\n", + " values_array = np.array(values_batch)\n", + " \n", + " # Compute returns and advantages using GAE\n", + " advantages = np.zeros_like(rewards_array)\n", + " lastgaelam = 0\n", + " \n", + " with torch.no_grad():\n", + " next_value = policy.get_action_and_value(torch.FloatTensor(obs))[3].squeeze().numpy()\n", + " \n", + " for t in reversed(range(num_steps)):\n", + " if t == num_steps - 1:\n", + " nextnonterminal = 1.0 - dones_array[t]\n", + " nextvalues = next_value\n", + " else:\n", + " nextnonterminal = 1.0 - dones_array[t + 1]\n", + " nextvalues = values_array[t + 1]\n", + " delta = rewards_array[t] + gamma * nextvalues * nextnonterminal - values_array[t]\n", + " advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam\n", + " \n", + " returns = advantages + values_array\n", + " \n", + " # Flatten batches\n", + " b_obs = torch.FloatTensor(obs_array.reshape(-1, obs_array.shape[-1]))\n", + " b_actions = torch.LongTensor(actions_array.reshape(-1, 1))\n", + " b_logprobs = torch.FloatTensor(logprobs_array.reshape(-1, 1))\n", + " b_advantages = torch.FloatTensor(advantages.reshape(-1, 1))\n", + " b_returns = torch.FloatTensor(returns.reshape(-1, 1))\n", + " \n", + " # Normalize advantages\n", + " b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)\n", + " \n", + " # PPO update with clipping\n", + " _, newlogprobs, entropy, newvalue = policy.get_action_and_value(b_obs, b_actions)\n", + " \n", + " # Policy loss with clipping\n", + " logratio = newlogprobs - b_logprobs\n", + " ratio = logratio.exp()\n", + " \n", + " pg_loss1 = -b_advantages * ratio\n", + " pg_loss2 = -b_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)\n", + " policy_loss = torch.max(pg_loss1, pg_loss2).mean()\n", + " \n", + " # Value loss\n", + " value_loss = F.mse_loss(newvalue, b_returns)\n", + " \n", + " # Entropy bonus\n", + " entropy_loss = entropy.mean()\n", + " \n", + " # Total loss\n", + " loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_loss\n", + " \n", + " # Optimization step\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)\n", + " optimizer.step()\n", + " \n", + " # Logging\n", + " if episode_rewards:\n", + " writer.add_scalar('charts/episodic_return', np.mean(episode_rewards[-10:]), global_step)\n", + " writer.add_scalar('charts/episodic_length', np.mean(episode_lengths[-10:]), global_step)\n", + " \n", + " writer.add_scalar('losses/policy_loss', policy_loss.item(), global_step)\n", + " writer.add_scalar('losses/value_loss', value_loss.item(), global_step)\n", + " writer.add_scalar('losses/entropy', entropy_loss.item(), global_step)\n", + " writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)\n", + " \n", + " if (update + 1) % 10 == 0:\n", + " print(f\"Update {update+1}/{num_updates}, Steps: {global_step}/{total_timesteps}\")\n", + "\n", + "puffer_training_time = time.time() - start_time\n", + "print(f\"\\nPufferLib Training Time: {puffer_training_time:.2f} seconds\")\n", + "print(f\"Average Episode Return: {np.mean(episode_rewards[-50:]) if episode_rewards else 0:.2f}\")\n", + "\n", + "# Save the model\n", + "torch.save(policy.state_dict(), \"mobile_env_puffer_model.pt\")\n", + "print(\"Model saved to 'mobile_env_puffer_model.pt'\")\n", + "\n", + "# Close\n", + "writer.close()\n", + "vec_env.close()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 3: Performance Comparison\n", + "\n", + "Let's compare the training times and view the results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Performance comparison\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"PERFORMANCE COMPARISON\")\n", + "print(\"=\"*60)\n", + "print(f\"Stable-Baselines3 Training Time: {sb3_training_time:.2f} seconds\")\n", + "print(f\"PufferLib Training Time: {puffer_training_time:.2f} seconds\")\n", + "print(f\"\\nSpeedup: {sb3_training_time / puffer_training_time:.2f}x\")\n", + "print(\"=\"*60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 4: Visualize Results with TensorBoard\n", + "\n", + "To view the training curves and compare convergence:\n", + "\n", + "```bash\n", + "tensorboard --logdir ./logs\n", + "```\n", + "\n", + "Then open your browser to http://localhost:6006\n", + "\n", + "You should see:\n", + "- Training reward curves for both approaches\n", + "- Episode length statistics\n", + "- Loss curves\n", + "- Value function estimates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optional: Launch TensorBoard from notebook (requires jupyter-tensorboard)\n", + "# %load_ext tensorboard\n", + "# %tensorboard --logdir ./logs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 5: Evaluate Trained Models\n", + "\n", + "Let's evaluate both trained models to compare their performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluate Stable-Baselines3 model\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"EVALUATING MODELS\")\n", + "print(\"=\"*60)\n", + "\n", + "def evaluate_model(env, model, n_episodes=10, model_type=\"sb3\"):\n", + " \"\"\"Evaluate a trained model.\"\"\"\n", + " episode_rewards = []\n", + " episode_lengths = []\n", + " \n", + " for episode in range(n_episodes):\n", + " obs = env.reset()\n", + " done = False\n", + " episode_reward = 0\n", + " episode_length = 0\n", + " \n", + " while not done:\n", + " if model_type == \"sb3\":\n", + " action, _ = model.predict(obs, deterministic=True)\n", + " else: # puffer\n", + " with torch.no_grad():\n", + " obs_tensor = torch.FloatTensor(obs).unsqueeze(0)\n", + " action_logits, _ = model(obs_tensor)\n", + " action = torch.argmax(action_logits, dim=-1).item()\n", + " \n", + " obs, reward, done, info = env.step(action)\n", + " episode_reward += reward\n", + " episode_length += 1\n", + " \n", + " episode_rewards.append(episode_reward)\n", + " episode_lengths.append(episode_length)\n", + " \n", + " return np.mean(episode_rewards), np.std(episode_rewards), np.mean(episode_lengths)\n", + "\n", + "# Evaluate SB3 model\n", + "sb3_eval_env = create_mobile_env()\n", + "sb3_mean_reward, sb3_std_reward, sb3_mean_length = evaluate_model(\n", + " sb3_eval_env, sb3_model, n_episodes=10, model_type=\"sb3\"\n", + ")\n", + "print(f\"\\nStable-Baselines3 Model:\")\n", + "print(f\" Mean Reward: {sb3_mean_reward:.2f} +/- {sb3_std_reward:.2f}\")\n", + "print(f\" Mean Episode Length: {sb3_mean_length:.2f}\")\n", + "\n", + "# Evaluate PufferLib model\n", + "puffer_eval_env = create_mobile_env()\n", + "puffer_mean_reward, puffer_std_reward, puffer_mean_length = evaluate_model(\n", + " puffer_eval_env, policy, n_episodes=10, model_type=\"puffer\"\n", + ")\n", + "print(f\"\\nPufferLib Model:\")\n", + "print(f\" Mean Reward: {puffer_mean_reward:.2f} +/- {puffer_std_reward:.2f}\")\n", + "print(f\" Mean Episode Length: {puffer_mean_length:.2f}\")\n", + "\n", + "print(\"\\n\" + \"=\"*60)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "This notebook demonstrated:\n", + "\n", + "1. **Training with Stable-Baselines3**: Using the PPO algorithm with standard configuration\n", + "2. **Training with PufferLib**: Using PufferLib's optimized environment wrapper for faster training\n", + "3. **Performance Comparison**: Comparing training speed between both approaches\n", + "4. **TensorBoard Visualization**: Logging metrics for detailed analysis\n", + "5. **Model Evaluation**: Testing final performance of both trained models\n", + "\n", + "### Key Takeaways:\n", + "\n", + "- **PufferLib** typically offers faster training through vectorization and optimized environment handling\n", + "- **Stable-Baselines3** provides more mature, well-tested implementations with extensive documentation\n", + "- Both approaches can achieve good performance on mobile-env tasks\n", + "- The choice depends on your priorities: speed (PufferLib) vs. ease of use (SB3)\n", + "\n", + "### Next Steps:\n", + "\n", + "1. Experiment with different hyperparameters\n", + "2. Try other mobile-env scenarios (e.g., mobile-medium-central-v0)\n", + "3. Compare different algorithms (PPO, A2C, SAC)\n", + "4. Extend training duration for better convergence\n", + "5. Implement custom reward shaping for your specific use case" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}