Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions furuta/controls/controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def __init__(
control_freq: float = 100.0,
t_final: float = 0.5,
u_lim: float = 0.1,
Q: np.ndarray = np.array([10, 50, 1, 1]),
R: np.ndarray = np.array([0.1]),
Q: np.ndarray = np.array([1, 50, 10, 1]),
R: np.ndarray = np.array([1.0]),
S: np.ndarray = np.array([1.0]),
M: int = 10,
):
Expand Down
169 changes: 169 additions & 0 deletions scripts/real_nmpc_swing_up.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import time
from pathlib import Path

import crocoddyl
import numpy as np

import furuta
from furuta.controls.controllers import SwingUpController
from furuta.logger import Loader, SimpleLogger
from furuta.plotter import Plotter
from furuta.robot import Robot, RobotModel
from furuta.state import Signal, State
from furuta.viewer import Viewer3D

DEVICE = "/dev/ttyACM0"

if __name__ == "__main__":
# Init robot
robot = Robot(DEVICE)

# Constants
control_freq = 100.0
dt = 1.0 / control_freq
t_MPC = 0.5 # MPC Time Horizon (s)
t_XP = 3.0 # Sim duration (s)

# Initial state
init_state = np.zeros(4)
# Desired State
x_ref = np.array([0.0, np.pi, 0.0, 0.0])

# Create the controller
model = RobotModel().robot
controller = SwingUpController(model, x_ref, control_freq, t_MPC)

# Create the data logger
file_name = f"{time.strftime('%Y%m%d-%H%M%S')}.mcap"
log_dir = Path(furuta.__path__[0]).parent / "logs" / "xp" / "nmpc_swing_up"
log_dir.mkdir(parents=True, exist_ok=True)
log_path = log_dir / file_name
logger = SimpleLogger(log_path)

# Solve the OCP a first time to get the warm start
controller.compute_command(init_state)
xs = controller.get_trajectoy()
us = controller.get_command()
# Create the robot viewer
viewer = Viewer3D(model)

# Display the solution
viewer.animate(np.arange(0, t_MPC, dt), xs)
crocoddyl.plotOCSolution(xs, us)

# Warm start
x_ws = xs
u_ws = us

# Wait for user input to start the control loop
input("Go?")

# Reset encoders
robot.reset()

t = 0.0
u = 0.0

(
motor_position,
pendulum_position,
motor_velocity,
pendulum_velocity,
timestamp,
motor_command,
) = robot.step(0.0)

t0 = timestamp
x = np.array(
[
motor_position,
pendulum_position,
motor_velocity,
pendulum_velocity,
]
)

while timestamp - t0 < t_XP:
# Update residual ref
controller.control_rate_residual.reference = np.array([u])

# Solve the OCP
start = time.time()
u = controller.compute_command(x, 20, x_ws, u_ws)
stop = time.time()
compute_time = stop - start

(
desired_motor_position,
desired_pendulum_position,
desired_motor_velocity,
desired_pendulum_velocity,
) = controller.get_trajectoy()[1]

# Basic safety
if abs(desired_pendulum_position) > 2 * np.pi:
break
if abs(desired_motor_position) > np.pi:
break
if abs(desired_motor_velocity) > 50.0:
break

(
motor_position,
pendulum_position,
motor_velocity,
pendulum_velocity,
timestamp,
motor_command,
) = robot.step_PID(desired_motor_position, desired_motor_velocity)

# Basic safety
if abs(pendulum_position) > 2 * np.pi:
break
if abs(motor_position) > np.pi:
break
if abs(motor_velocity) > 50.0:
break

state = State(
motor_position=Signal(measured=motor_position, desired=desired_motor_position),
motor_velocity=Signal(measured=motor_velocity, desired=desired_motor_velocity),
pendulum_position=Signal(
measured=pendulum_position, desired=desired_pendulum_position
),
pendulum_velocity=Signal(
measured=pendulum_velocity, desired=desired_pendulum_velocity
),
action=motor_command,
timing=compute_time,
)
logger.update(int((timestamp - t0) * 1e9), state)

# Get the warm start from the controller
x_ws, u_ws = controller.get_warm_start()

# Update state
x = np.array(
[
motor_position,
pendulum_position,
motor_velocity,
pendulum_velocity,
]
)

# Close logger
logger.stop()

# Read log
loader = Loader()
times, states_dict = loader.load(log_path)

# Plot
plotter = Plotter(times, states_dict)
plotter.plot()

# Animate
states = loader.get_state("measured")
viewer.animate(times, states)
viewer.close()