Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions metaworld/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gymnasium as gym
import numpy as np
import numpy.typing as npt
from tqdm import tqdm

from metaworld.env_dict import ALL_V3_ENVIRONMENTS

Expand Down Expand Up @@ -63,6 +64,9 @@ def evaluation(
task_name: [] for task_name in set(task_names)
}

pbar_envs_done = tqdm(total=len(set(task_names)) *
num_episodes, desc="Evaluation")

def eval_done(returns):
return all(len(r) >= num_episodes for _, r in returns.items())

Expand All @@ -75,6 +79,10 @@ def eval_done(returns):

for i, env_ended in enumerate(dones):
if env_ended:
task = task_names[i]
current_count = len(episodic_returns[task])
if current_count < num_episodes:
pbar_envs_done.update(1)
episodic_returns[task_names[i]].append(
float(infos["final_info"]["episode"]["r"][i])
)
Expand Down
13 changes: 10 additions & 3 deletions metaworld/policies/sawyer_peg_insertion_side_v3_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,21 @@ def _desired_pos(o_d: dict[str, npt.NDArray[np.float64]]) -> npt.NDArray[Any]:
# Z is constant at .16
pos_hole = np.array([-0.35, o_d["goal_pos"][1], 0.16])

gripper_dist = o_d["gripper_distance_apart"]

# gets rid of a tiny bump
hole_z_offset = np.array([0.0, 0.0, 0.01])

if np.linalg.norm(pos_curr[:2] - pos_peg[:2]) > 0.04:
# move to pickup position above peg
return pos_peg + np.array([0.0, 0.0, 0.3])
elif abs(pos_curr[2] - pos_peg[2]) > 0.025:
elif abs(pos_curr[2] - pos_peg[2]) > 0.025 and gripper_dist > 0.5:
# move down to peg only if not holding
return pos_peg
elif np.linalg.norm(pos_peg[1:] - pos_hole[1:]) > 0.03:
return pos_hole + np.array([0.4, 0.0, 0.0])
return pos_hole + np.array([0.4, 0.0, 0.0]) + hole_z_offset
else:
return pos_hole
return pos_hole + hole_z_offset

@staticmethod
def _grab_effort(o_d: dict[str, npt.NDArray[np.float64]]) -> float:
Expand Down
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,18 @@ dependencies = [
"mujoco>=3.0.0",
"numpy>=1.18",
"scipy>=1.4.1",
"imageio"
"imageio",
]

[project.optional-dependencies]
# Update dependencies in `all` if any are added or removed
testing = ["ipdb", "memory_profiler", "pyquaternion==0.9.5", "pytest>=4.4.0"]
testing = [
"ipdb",
"memory_profiler",
"pyquaternion==0.9.5",
"pytest>=4.4.0",
"tqdm",
]
dev = ["black", "isort", "mypy"]

[project.urls]
Expand Down
21 changes: 18 additions & 3 deletions tests/metaworld/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,29 @@ def test_evaluation():
max_episode_steps=max_episode_steps,
vector_strategy="async",
)
num_envs = envs.num_envs
agent = ScriptedPolicyAgent(envs)
mean_success_rate, mean_returns, success_rate_per_task, _ = evaluation.evaluation(
agent, envs, num_episodes=num_episodes
)
envs.close()
assert isinstance(mean_returns, float)
assert mean_success_rate >= 0.80
assert len(success_rate_per_task) == envs.num_envs
assert np.all(np.array(list(success_rate_per_task.values())) >= 0.80)
assert len(success_rate_per_task) == num_envs
worst_accepted_fail_rate = 0.8
failed_envs_names = []
failed_envs_rates = []
for task_name, success_rate in success_rate_per_task.items():
if success_rate < worst_accepted_fail_rate:
failed_envs_names.append(task_name)
failed_envs_rates.append(success_rate)
if len(failed_envs_names) > 0:
print(
f"The following environments failed the success rate threshold of {worst_accepted_fail_rate*100}%:"
)
for name, rate in zip(failed_envs_names, failed_envs_rates):
print(f"- {name}: {rate*100}% success rate")
assert False, "Some environments did not meet the success rate threshold."
assert mean_success_rate >= worst_accepted_fail_rate


# @pytest.mark.skip
Expand Down