diff --git a/metaworld/evaluation.py b/metaworld/evaluation.py index b4328f2a5..e80314204 100644 --- a/metaworld/evaluation.py +++ b/metaworld/evaluation.py @@ -5,6 +5,7 @@ import gymnasium as gym import numpy as np import numpy.typing as npt +from tqdm import tqdm from metaworld.env_dict import ALL_V3_ENVIRONMENTS @@ -63,6 +64,9 @@ def evaluation( task_name: [] for task_name in set(task_names) } + pbar_envs_done = tqdm(total=len(set(task_names)) * + num_episodes, desc="Evaluation") + def eval_done(returns): return all(len(r) >= num_episodes for _, r in returns.items()) @@ -75,6 +79,10 @@ def eval_done(returns): for i, env_ended in enumerate(dones): if env_ended: + task = task_names[i] + current_count = len(episodic_returns[task]) + if current_count < num_episodes: + pbar_envs_done.update(1) episodic_returns[task_names[i]].append( float(infos["final_info"]["episode"]["r"][i]) ) diff --git a/metaworld/policies/sawyer_peg_insertion_side_v3_policy.py b/metaworld/policies/sawyer_peg_insertion_side_v3_policy.py index 3763a4576..902639755 100644 --- a/metaworld/policies/sawyer_peg_insertion_side_v3_policy.py +++ b/metaworld/policies/sawyer_peg_insertion_side_v3_policy.py @@ -44,14 +44,21 @@ def _desired_pos(o_d: dict[str, npt.NDArray[np.float64]]) -> npt.NDArray[Any]: # Z is constant at .16 pos_hole = np.array([-0.35, o_d["goal_pos"][1], 0.16]) + gripper_dist = o_d["gripper_distance_apart"] + + # gets rid of a tiny bump + hole_z_offset = np.array([0.0, 0.0, 0.01]) + if np.linalg.norm(pos_curr[:2] - pos_peg[:2]) > 0.04: + # move to pickup position above peg return pos_peg + np.array([0.0, 0.0, 0.3]) - elif abs(pos_curr[2] - pos_peg[2]) > 0.025: + elif abs(pos_curr[2] - pos_peg[2]) > 0.025 and gripper_dist > 0.5: + # move down to peg only if not holding return pos_peg elif np.linalg.norm(pos_peg[1:] - pos_hole[1:]) > 0.03: - return pos_hole + np.array([0.4, 0.0, 0.0]) + return pos_hole + np.array([0.4, 0.0, 0.0]) + hole_z_offset else: - return pos_hole + return pos_hole + hole_z_offset @staticmethod def _grab_effort(o_d: dict[str, npt.NDArray[np.float64]]) -> float: diff --git a/pyproject.toml b/pyproject.toml index 7ecf513aa..f8931edf7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,12 +28,18 @@ dependencies = [ "mujoco>=3.0.0", "numpy>=1.18", "scipy>=1.4.1", - "imageio" + "imageio", ] [project.optional-dependencies] # Update dependencies in `all` if any are added or removed -testing = ["ipdb", "memory_profiler", "pyquaternion==0.9.5", "pytest>=4.4.0"] +testing = [ + "ipdb", + "memory_profiler", + "pyquaternion==0.9.5", + "pytest>=4.4.0", + "tqdm", +] dev = ["black", "isort", "mypy"] [project.urls] diff --git a/tests/metaworld/test_evaluation.py b/tests/metaworld/test_evaluation.py index f67e587fd..8c97cfd7b 100644 --- a/tests/metaworld/test_evaluation.py +++ b/tests/metaworld/test_evaluation.py @@ -95,14 +95,29 @@ def test_evaluation(): max_episode_steps=max_episode_steps, vector_strategy="async", ) + num_envs = envs.num_envs agent = ScriptedPolicyAgent(envs) mean_success_rate, mean_returns, success_rate_per_task, _ = evaluation.evaluation( agent, envs, num_episodes=num_episodes ) + envs.close() assert isinstance(mean_returns, float) - assert mean_success_rate >= 0.80 - assert len(success_rate_per_task) == envs.num_envs - assert np.all(np.array(list(success_rate_per_task.values())) >= 0.80) + assert len(success_rate_per_task) == num_envs + worst_accepted_fail_rate = 0.8 + failed_envs_names = [] + failed_envs_rates = [] + for task_name, success_rate in success_rate_per_task.items(): + if success_rate < worst_accepted_fail_rate: + failed_envs_names.append(task_name) + failed_envs_rates.append(success_rate) + if len(failed_envs_names) > 0: + print( + f"The following environments failed the success rate threshold of {worst_accepted_fail_rate*100}%:" + ) + for name, rate in zip(failed_envs_names, failed_envs_rates): + print(f"- {name}: {rate*100}% success rate") + assert False, "Some environments did not meet the success rate threshold." + assert mean_success_rate >= worst_accepted_fail_rate # @pytest.mark.skip