Hindsight Experience Replay (HER) — low-level PyTorch (DDPG) in a goal-based environment#

HER is a replay-buffer trick for sparse-reward, goal-conditioned reinforcement learning.

Instead of throwing away failed episodes, we reinterpret them as successes for other goals that were actually achieved later in the same episode.


Learning goals#

  • Understand goal-conditioned RL and why sparse rewards are hard.

  • Precisely define goal relabeling (HER) with math notation.

  • Implement DDPG + HER from scratch in PyTorch (networks, targets, replay buffer, training loop).

  • Use Plotly to visualize reward/success per episode and learning signals (losses, Q-values).

  • See how the same idea is used in Stable-Baselines / Stable-Baselines3.

Prerequisites#

  • Basic RL: MDPs, value functions, off-policy learning.

  • DDPG idea: deterministic actor + critic, target networks, replay buffer.

  • PyTorch basics: nn.Module, optimizers, backprop.

  • Goal-based env interface: observations are a dict with observation, achieved_goal, desired_goal.

Why HER?#

In many tasks you only get a reward when you succeed (a sparse reward). Early on, success may be extremely rare, so learning stalls.

HER addresses this by turning hindsight into supervision:

  • you attempted goal \(g\) and failed,

  • but you did reach many other states / achieved-goals along the way,

  • so we can relabel the goal to something that actually happened and learn from it.

Goal-conditioned RL notation#

A goal-conditioned environment exposes observations as a dict: \(o_t = \{x_t,\; a_t^g,\; g\}\), where:

  • \(x_t\) is the “regular” observation (e.g. position/velocity),

  • \(a_t^g\) is the achieved goal at time \(t\) (often a subset of the state, e.g. end-effector position),

  • \(g\) is the desired goal.

The reward is usually defined through a task-specific distance function \(d(\cdot,\cdot)\) and a threshold \(\epsilon\).

Sparse goal reward (common in HER benchmarks): \(r_t = r(a_{t+1}^g, g) = -\mathbb{1}[d(a_{t+1}^g, g) > \epsilon]\), so the reward is \(0\) on success and \(-1\) otherwise.
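As a quick standalone sketch of this sparse reward in NumPy (the threshold value here is an assumption that matches the distance_threshold used later in the notebook):

import numpy as np

def sparse_reward(achieved_goal, desired_goal, eps=0.05):
    # 0 on success (within eps of the goal), -1 otherwise
    d = np.linalg.norm(np.asarray(achieved_goal) - np.asarray(desired_goal), axis=-1)
    return -(d > eps).astype(np.float32)

print(sparse_reward(np.array([0.0, 0.0]), np.array([0.0, 0.04])))  # within threshold -> reward 0
print(sparse_reward(np.array([0.0, 0.0]), np.array([0.5, 0.5])))   # outside threshold -> reward -1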

Goal relabeling (HER) — precise definition#

Consider an episode/trajectory of length \(T\) collected under goal \(g\): \(\tau = \big\{(x_t, a_t, r_t, x_{t+1}, a_t^g, a_{t+1}^g, g)\big\}_{t=0}^{T-1}\)

HER constructs additional replay transitions by sampling an alternative goal \(g'\) from the achieved goals of the same episode. The most common choice is the future strategy: \(g' \sim \mathrm{Uniform}\big(\{a_{t+1}^g, a_{t+2}^g, \ldots, a_T^g\}\big)\)

Then we recompute the reward under the relabeled goal: \(r'_t = r(a_{t+1}^g, g')\)

and store the relabeled transition: \(\big((x_t, g'),\; a_t,\; r'_t,\; (x_{t+1}, g'),\; \text{done}_t\big)\)
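To make the mechanics concrete, here is a minimal sketch of relabeling a single transition with the future strategy (the function and the episode layout are illustrative; the full buffer implementation appears later in the notebook):

import numpy as np

def relabel_future(episode, t, compute_reward, rng):
    """Pick a future achieved goal g' and recompute the reward for transition t."""
    T = len(episode["actions"])
    future_t = int(rng.integers(t, T))               # any step from t to T-1
    g_new = episode["next_achieved_goal"][future_t]  # a goal that was actually reached
    r_new = compute_reward(episode["next_achieved_goal"][t], g_new)
    return g_new, r_new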

How many relabeled transitions?#

A common parameterization is n_sampled_goal = k: for each real transition, create \(k\) hindsight transitions. Equivalently, if you sample a batch from the replay buffer, you can relabel each sampled transition with probability \(p_{\mathrm{HER}} = \frac{k}{k+1}\), so on average you see \(k\) hindsight samples per real sample.
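Worked example with the default k = 4:

k = 4
p_her = k / (k + 1)
print(p_her)  # 0.8 -> on average 4 hindsight samples for every real sample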

import math
import time
from dataclasses import asdict, dataclass, field
from typing import Callable, Dict, List, Tuple

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import os
import plotly.io as pio

import gymnasium as gym
from gymnasium import spaces

import torch
import torch.nn as nn
import torch.nn.functional as F


pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE
device(type='cpu')

A tiny goal-based environment (Gymnasium Dict observation)#

We will build a simple 2D point-mass that must reach a randomly sampled 2D goal.

  • observation: position + velocity

  • achieved_goal: position

  • desired_goal: target position

The reward is sparse: \(0\) if the point is within a threshold of the goal, else \(-1\).

class PointReachGoalEnv(gym.Env):
    """A minimal goal-based env compatible with HER-style replay buffers.

    Observation is a dict with keys:
      - observation: (pos_x, pos_y, vel_x, vel_y)
      - achieved_goal: (pos_x, pos_y)
      - desired_goal: (goal_x, goal_y)
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        max_steps: int = 50,
        dt: float = 0.05,
        goal_range: float = 1.0,
        max_speed: float = 1.0,
        distance_threshold: float = 0.05,
        action_scale: float = 10.0,
        seed: int | None = None,
    ):
        super().__init__()
        self.max_steps = int(max_steps)
        self.dt = float(dt)
        self.goal_range = float(goal_range)
        self.max_speed = float(max_speed)
        self.distance_threshold = float(distance_threshold)
        self.action_scale = float(action_scale)
        self._rng = np.random.default_rng(seed)

        # actions are accelerations, clipped to [-1, 1]
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)

        obs_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32)
        goal_space = spaces.Box(low=-self.goal_range, high=self.goal_range, shape=(2,), dtype=np.float32)
        self.observation_space = spaces.Dict(
            {
                "observation": obs_space,
                "achieved_goal": goal_space,
                "desired_goal": goal_space,
            }
        )

        self._t = 0
        self._pos = np.zeros(2, dtype=np.float32)
        self._vel = np.zeros(2, dtype=np.float32)
        self._goal = np.zeros(2, dtype=np.float32)

    def _sample_goal(self) -> np.ndarray:
        return self._rng.uniform(-self.goal_range, self.goal_range, size=(2,)).astype(np.float32)

    def _get_obs(self) -> Dict[str, np.ndarray]:
        obs = np.concatenate([self._pos, self._vel]).astype(np.float32)
        return {
            "observation": obs,
            "achieved_goal": self._pos.copy(),
            "desired_goal": self._goal.copy(),
        }

    def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info=None) -> np.ndarray:
        """Vectorized sparse reward: 0 if within threshold else -1."""
        achieved_goal = np.asarray(achieved_goal, dtype=np.float32)
        desired_goal = np.asarray(desired_goal, dtype=np.float32)
        d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return -(d > self.distance_threshold).astype(np.float32)

    def compute_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
        achieved_goal = np.asarray(achieved_goal, dtype=np.float32)
        desired_goal = np.asarray(desired_goal, dtype=np.float32)
        d = np.linalg.norm(achieved_goal - desired_goal, axis=-1)
        return (d <= self.distance_threshold).astype(np.float32)

    def reset(self, *, seed: int | None = None, options=None):
        super().reset(seed=seed)
        if seed is not None:
            self._rng = np.random.default_rng(seed)

        self._t = 0
        self._pos = self._rng.uniform(-self.goal_range, self.goal_range, size=(2,)).astype(np.float32)
        self._vel = np.zeros(2, dtype=np.float32)
        self._goal = self._sample_goal()
        obs = self._get_obs()
        info = {"is_success": float(self.compute_success(obs["achieved_goal"], obs["desired_goal"]))}
        return obs, info

    def step(self, action: np.ndarray):
        self._t += 1

        action = np.asarray(action, dtype=np.float32)
        action = np.clip(action, self.action_space.low, self.action_space.high)
        accel = self.action_scale * action

        self._vel = np.clip(self._vel + accel * self.dt, -self.max_speed, self.max_speed)
        self._pos = np.clip(self._pos + self._vel * self.dt, -self.goal_range, self.goal_range)

        obs = self._get_obs()
        reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"]))
        terminated = False
        truncated = self._t >= self.max_steps
        info = {"is_success": float(self.compute_success(obs["achieved_goal"], obs["desired_goal"]))}
        return obs, reward, terminated, truncated, info
env = PointReachGoalEnv(max_steps=50, seed=SEED)
obs, info = env.reset()
print(obs.keys(), info)
print('observation shape:', obs['observation'].shape)
print('achieved_goal shape:', obs['achieved_goal'].shape)
print('desired_goal shape:', obs['desired_goal'].shape)

a = env.action_space.sample()
obs2, r, terminated, truncated, info2 = env.step(a)
print('reward:', r, 'done:', terminated or truncated, 'success:', info2['is_success'])
dict_keys(['observation', 'achieved_goal', 'desired_goal']) {'is_success': 0.0}
observation shape: (4,)
achieved_goal shape: (2,)
desired_goal shape: (2,)
reward: -1.0 done: False success: 0.0

DDPG (goal-conditioned)#

We use a standard DDPG setup:

  • Actor \(\mu_\theta(x, g)\) outputs a deterministic action.

  • Critic \(Q_\phi(x, g, a)\) estimates the action-value.

  • Target networks \((\theta', \phi')\) are updated with Polyak averaging.

The only goal-conditioning detail is the network input: we concatenate the observation with the goal, \(s_t := [x_t \;\|\; g]\).
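For a single dict observation this looks as follows (a minimal sketch; the training loop below does exactly this):

obs, _ = env.reset()
state = np.concatenate([obs["observation"], obs["desired_goal"]])  # shape (4 + 2,) = (6,)
state.shape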

def mlp(sizes: List[int], activation: nn.Module, output_activation: nn.Module | None = None) -> nn.Sequential:
    layers: List[nn.Module] = []
    for i in range(len(sizes) - 1):
        act = activation if i < len(sizes) - 2 else (output_activation or nn.Identity())
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        layers.append(act)
    return nn.Sequential(*layers)

class Actor(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_sizes=(256, 256)):
        super().__init__()
        self.net = mlp([state_dim, *hidden_sizes, action_dim], activation=nn.ReLU(), output_activation=nn.Tanh())

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)

class Critic(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, hidden_sizes=(256, 256)):
        super().__init__()
        self.net = mlp([state_dim + action_dim, *hidden_sizes, 1], activation=nn.ReLU(), output_activation=None)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        x = torch.cat([state, action], dim=-1)
        return self.net(x)

def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    with torch.no_grad():
        for tp, op in zip(target.parameters(), online.parameters(), strict=True):
            tp.data.mul_(1.0 - tau).add_(op.data, alpha=tau)

HER replay buffer (episodic + relabel-at-sample)#

We store full episodes so we can sample future achieved goals for relabeling.

When sampling a batch, each transition is relabeled with probability \(p_{\mathrm{HER}} = k/(k+1)\).

@dataclass
class HERConfig:
    buffer_capacity_transitions: int = 100_000
    n_sampled_goal: int = 4
    goal_selection_strategy: str = "future"  # future | final | episode


class HindsightReplayBuffer:
    def __init__(
        self,
        cfg: HERConfig,
        compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
    ):
        self.cfg = cfg
        self.compute_reward_fn = compute_reward_fn
        self.episodes: List[Dict[str, np.ndarray]] = []
        self.n_transitions = 0

        if self.cfg.goal_selection_strategy not in {"future", "final", "episode"}:
            raise ValueError(f"Unknown goal_selection_strategy: {self.cfg.goal_selection_strategy}")

    @property
    def her_probability(self) -> float:
        k = max(int(self.cfg.n_sampled_goal), 0)
        return 0.0 if k == 0 else k / (k + 1)

    def __len__(self) -> int:
        return int(self.n_transitions)

    def add_episode(self, episode: Dict[str, np.ndarray]) -> None:
        # expected keys: obs, actions, next_obs, achieved_goal, next_achieved_goal, desired_goal, dones
        T = int(len(episode["actions"]))
        if T == 0:
            return

        self.episodes.append(episode)
        self.n_transitions += T

        while self.n_transitions > self.cfg.buffer_capacity_transitions and self.episodes:
            removed = self.episodes.pop(0)
            self.n_transitions -= int(len(removed["actions"]))

    def sample(self, batch_size: int, rng: np.random.Generator) -> Dict[str, np.ndarray]:
        if not self.episodes:
            raise RuntimeError("Replay buffer is empty")

        batch_size = int(batch_size)
        ep_indices = rng.integers(0, len(self.episodes), size=(batch_size,))

        obs_batch = []
        next_obs_batch = []
        act_batch = []
        done_batch = []
        goal_batch = []
        next_achieved_batch = []
        achieved_batch = []
        t_batch = []

        for i in range(batch_size):
            ep = self.episodes[int(ep_indices[i])]
            T = int(len(ep["actions"]))
            t = int(rng.integers(0, T))

            obs_batch.append(ep["obs"][t])
            next_obs_batch.append(ep["next_obs"][t])
            act_batch.append(ep["actions"][t])
            achieved_batch.append(ep["achieved_goal"][t])
            next_achieved_batch.append(ep["next_achieved_goal"][t])
            goal_batch.append(ep["desired_goal"][t])
            done_batch.append(ep["dones"][t])
            t_batch.append(t)

        obs = np.asarray(obs_batch, dtype=np.float32)
        next_obs = np.asarray(next_obs_batch, dtype=np.float32)
        actions = np.asarray(act_batch, dtype=np.float32)
        achieved_goal = np.asarray(achieved_batch, dtype=np.float32)
        next_achieved_goal = np.asarray(next_achieved_batch, dtype=np.float32)
        desired_goal = np.asarray(goal_batch, dtype=np.float32)
        dones = np.asarray(done_batch, dtype=np.float32).reshape(-1, 1)
        t_idx = np.asarray(t_batch, dtype=np.int64)

        # base rewards (under original goal)
        rewards = self.compute_reward_fn(next_achieved_goal, desired_goal).astype(np.float32).reshape(-1, 1)

        # HER relabeling
        her_mask = rng.random(size=(batch_size,)) < self.her_probability
        if np.any(her_mask):
            for i in np.where(her_mask)[0]:
                ep = self.episodes[int(ep_indices[int(i)])]
                T = int(len(ep["actions"]))
                t = int(t_idx[int(i)])

                if self.cfg.goal_selection_strategy == "future":
                    future_t = int(rng.integers(t, T))
                elif self.cfg.goal_selection_strategy == "final":
                    future_t = T - 1
                elif self.cfg.goal_selection_strategy == "episode":
                    future_t = int(rng.integers(0, T))
                else:
                    raise RuntimeError("unreachable")

                g_her = ep["next_achieved_goal"][future_t].astype(np.float32)
                desired_goal[int(i)] = g_her
                rewards[int(i)] = self.compute_reward_fn(next_achieved_goal[int(i)], g_her).astype(np.float32)

        states = np.concatenate([obs, desired_goal], axis=-1)
        next_states = np.concatenate([next_obs, desired_goal], axis=-1)

        return {
            "states": states,
            "actions": actions,
            "rewards": rewards,
            "next_states": next_states,
            "dones": dones,
        }

DDPG agent (low-level PyTorch)#

This is a minimal DDPG implementation:

  • MSE critic loss against a bootstrapped target.

  • Deterministic policy gradient for the actor.

  • Soft updates for target networks.

Exploration is done by adding Gaussian noise to the actor output.

@dataclass
class DDPGConfig:
    gamma: float = 0.98
    tau: float = 0.005
    actor_lr: float = 1e-3
    critic_lr: float = 1e-3
    hidden_sizes: tuple[int, int] = (256, 256)
    batch_size: int = 256
    start_learning_after: int = 2_000  # transitions in buffer
    exploration_noise_std: float = 0.2
    grad_clip_norm: float | None = 1.0


class DDPGAgent:
    def __init__(self, state_dim: int, action_dim: int, cfg: DDPGConfig, device: torch.device):
        self.cfg = cfg
        self.device = device

        self.actor = Actor(state_dim, action_dim, hidden_sizes=self.cfg.hidden_sizes).to(device)
        self.critic = Critic(state_dim, action_dim, hidden_sizes=self.cfg.hidden_sizes).to(device)

        self.actor_targ = Actor(state_dim, action_dim, hidden_sizes=self.cfg.hidden_sizes).to(device)
        self.critic_targ = Critic(state_dim, action_dim, hidden_sizes=self.cfg.hidden_sizes).to(device)
        self.actor_targ.load_state_dict(self.actor.state_dict())
        self.critic_targ.load_state_dict(self.critic.state_dict())

        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=self.cfg.actor_lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=self.cfg.critic_lr)

    @torch.no_grad()
    def act(self, state: np.ndarray, rng: np.random.Generator, noise_std: float) -> np.ndarray:
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        a = self.actor(state_t).cpu().numpy()[0]
        a = a + rng.normal(0.0, noise_std, size=a.shape).astype(np.float32)
        return np.clip(a, -1.0, 1.0).astype(np.float32)

    def update(self, batch: Dict[str, np.ndarray]) -> Dict[str, float]:
        s = torch.as_tensor(batch["states"], dtype=torch.float32, device=self.device)
        a = torch.as_tensor(batch["actions"], dtype=torch.float32, device=self.device)
        r = torch.as_tensor(batch["rewards"], dtype=torch.float32, device=self.device)
        s2 = torch.as_tensor(batch["next_states"], dtype=torch.float32, device=self.device)
        d = torch.as_tensor(batch["dones"], dtype=torch.float32, device=self.device)

        # --- critic update ---
        with torch.no_grad():
            a2 = self.actor_targ(s2)
            q_targ = self.critic_targ(s2, a2)
            y = r + self.cfg.gamma * (1.0 - d) * q_targ

        q = self.critic(s, a)
        critic_loss = F.mse_loss(q, y)
        self.critic_opt.zero_grad(set_to_none=True)
        critic_loss.backward()
        if self.cfg.grad_clip_norm is not None:
            nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=float(self.cfg.grad_clip_norm))
        self.critic_opt.step()

        # --- actor update ---
        a_pi = self.actor(s)
        actor_loss = -self.critic(s, a_pi).mean()
        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        if self.cfg.grad_clip_norm is not None:
            nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=float(self.cfg.grad_clip_norm))
        self.actor_opt.step()

        # targets
        soft_update(self.actor_targ, self.actor, tau=self.cfg.tau)
        soft_update(self.critic_targ, self.critic, tau=self.cfg.tau)

        return {
            "critic_loss": float(critic_loss.item()),
            "actor_loss": float(actor_loss.item()),
            "q_mean": float(q.mean().item()),
        }

Training loop#

We will:

  1. collect one full episode,

  2. add it to the episodic replay buffer,

  3. run a number of gradient updates.

We log per-episode reward and success, and per-update losses/Q-values.

@dataclass
class TrainConfig:
    n_episodes: int = 150
    updates_per_episode: int = 50

    env_max_steps: int = 50
    env_distance_threshold: float = 0.05
    env_dt: float = 0.05
    env_goal_range: float = 1.0
    env_max_speed: float = 1.0
    env_action_scale: float = 10.0

    her: HERConfig = field(default_factory=HERConfig)
    ddpg: DDPGConfig = field(default_factory=DDPGConfig)


cfg = TrainConfig(
    n_episodes=150,
    updates_per_episode=50,
    env_max_steps=50,
    env_distance_threshold=0.05,
    env_dt=0.05,
    env_goal_range=1.0,
    env_max_speed=1.0,
    env_action_scale=10.0,
    her=HERConfig(buffer_capacity_transitions=100_000, n_sampled_goal=4, goal_selection_strategy="future"),
    ddpg=DDPGConfig(
        gamma=0.98,
        tau=0.005,
        actor_lr=1e-3,
        critic_lr=1e-3,
        batch_size=256,
        start_learning_after=2_000,
        exploration_noise_std=0.2,
        grad_clip_norm=1.0,
    ),
)
cfg
TrainConfig(n_episodes=150, updates_per_episode=50, env_max_steps=50, env_distance_threshold=0.05, env_dt=0.05, env_goal_range=1.0, env_max_speed=1.0, env_action_scale=10.0, her=HERConfig(buffer_capacity_transitions=100000, n_sampled_goal=4, goal_selection_strategy='future'), ddpg=DDPGConfig(gamma=0.98, tau=0.005, actor_lr=0.001, critic_lr=0.001, hidden_sizes=(256, 256), batch_size=256, start_learning_after=2000, exploration_noise_std=0.2, grad_clip_norm=1.0))
env = PointReachGoalEnv(
    max_steps=cfg.env_max_steps,
    distance_threshold=cfg.env_distance_threshold,
    dt=cfg.env_dt,
    goal_range=cfg.env_goal_range,
    max_speed=cfg.env_max_speed,
    action_scale=cfg.env_action_scale,
    seed=SEED,
)
obs_dim = int(env.observation_space["observation"].shape[0])
goal_dim = int(env.observation_space["desired_goal"].shape[0])
act_dim = int(env.action_space.shape[0])
state_dim = obs_dim + goal_dim

buffer = HindsightReplayBuffer(cfg.her, compute_reward_fn=lambda ag, g: env.compute_reward(ag, g))
agent = DDPGAgent(state_dim=state_dim, action_dim=act_dim, cfg=cfg.ddpg, device=DEVICE)

episode_returns: List[float] = []
episode_success: List[float] = []
episode_final_dist: List[float] = []

update_logs: List[Dict[str, float]] = []

t0 = time.time()
for ep in range(cfg.n_episodes):
    obs, info = env.reset()

    ep_obs = []
    ep_next_obs = []
    ep_actions = []
    ep_achieved = []
    ep_next_achieved = []
    ep_desired = []
    ep_dones = []
    ep_rewards = []

    for t in range(cfg.env_max_steps):
        state = np.concatenate([obs["observation"], obs["desired_goal"]], axis=-1).astype(np.float32)
        action = agent.act(state, rng=rng, noise_std=cfg.ddpg.exploration_noise_std)

        next_obs, reward, terminated, truncated, info = env.step(action)
        done = float(terminated or truncated)
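        # note: truncation (hitting the time limit) is folded into `done`, so the TD
        # target stops bootstrapping at the horizon; reference HER implementations
        # often keep bootstrapping through the time limit instead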

        ep_obs.append(obs["observation"].copy())
        ep_next_obs.append(next_obs["observation"].copy())
        ep_actions.append(action.copy())
        ep_achieved.append(obs["achieved_goal"].copy())
        ep_next_achieved.append(next_obs["achieved_goal"].copy())
        ep_desired.append(obs["desired_goal"].copy())
        ep_dones.append(done)
        ep_rewards.append(float(reward))

        obs = next_obs
        if terminated or truncated:
            break

    # add episode to buffer
    episode = {
        "obs": np.asarray(ep_obs, dtype=np.float32),
        "next_obs": np.asarray(ep_next_obs, dtype=np.float32),
        "actions": np.asarray(ep_actions, dtype=np.float32),
        "achieved_goal": np.asarray(ep_achieved, dtype=np.float32),
        "next_achieved_goal": np.asarray(ep_next_achieved, dtype=np.float32),
        "desired_goal": np.asarray(ep_desired, dtype=np.float32),
        "dones": np.asarray(ep_dones, dtype=np.float32),
    }
    buffer.add_episode(episode)

    ep_return = float(np.sum(ep_rewards))
    ep_success = float(info.get("is_success", 0.0))
    final_dist = float(np.linalg.norm(obs["achieved_goal"] - obs["desired_goal"]))
    episode_returns.append(ep_return)
    episode_success.append(ep_success)
    episode_final_dist.append(final_dist)

    # updates
    if len(buffer) >= cfg.ddpg.start_learning_after:
        for _ in range(cfg.updates_per_episode):
            batch = buffer.sample(cfg.ddpg.batch_size, rng=rng)
            stats = agent.update(batch)
            stats["episode"] = float(ep)
            update_logs.append(stats)

    if (ep + 1) % 25 == 0:
        sr = float(np.mean(episode_success[-25:]))
        print(
            f"ep {ep+1:4d}/{cfg.n_episodes} | "
            f"avg_success(last25)={sr:.2f} | "
            f"buffer={len(buffer):6d} | "
            f"elapsed={time.time()-t0:.1f}s"
        )

print("done")
ep   25/150 | avg_success(last25)=0.00 | buffer=  1250 | elapsed=0.1s
ep   50/150 | avg_success(last25)=0.00 | buffer=  2500 | elapsed=9.0s
ep   75/150 | avg_success(last25)=0.00 | buffer=  3750 | elapsed=32.7s
ep  100/150 | avg_success(last25)=0.00 | buffer=  5000 | elapsed=56.3s
ep  125/150 | avg_success(last25)=0.00 | buffer=  6250 | elapsed=80.2s
ep  150/150 | avg_success(last25)=0.12 | buffer=  7500 | elapsed=104.5s
done
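
Before plotting, it can be useful to evaluate the greedy policy without exploration noise; a minimal sketch, reusing agent, env, and rng from above (the episode count is arbitrary):

def evaluate(agent, env, n_episodes=20):
    """Roll out the deterministic policy and return the average final success."""
    successes = []
    for _ in range(n_episodes):
        obs, info = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            state = np.concatenate([obs["observation"], obs["desired_goal"]]).astype(np.float32)
            action = agent.act(state, rng=rng, noise_std=0.0)  # no exploration noise
            obs, reward, terminated, truncated, info = env.step(action)
        successes.append(float(info["is_success"]))
    return float(np.mean(successes))

# evaluate(agent, env)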

Plot reward and success per episode (Plotly)#

We visualize:

  • episode return (sum of sparse rewards)

  • episode success (0/1) and a moving average

df_ep = pd.DataFrame(
    {
        "episode": np.arange(len(episode_returns)),
        "return": episode_returns,
        "success": episode_success,
        "final_dist": episode_final_dist,
    }
)
df_ep["success_ma_10"] = df_ep["success"].rolling(10, min_periods=1).mean()
df_ep["return_ma_10"] = df_ep["return"].rolling(10, min_periods=1).mean()

fig = go.Figure()
fig.add_trace(go.Scatter(x=df_ep["episode"], y=df_ep["return"], mode="lines", name="return"))
fig.add_trace(go.Scatter(x=df_ep["episode"], y=df_ep["return_ma_10"], mode="lines", name="return (ma10)"))
fig.update_layout(title="Episode return (sparse)", xaxis_title="episode", yaxis_title="return")
fig.show()

fig = go.Figure()
fig.add_trace(go.Scatter(x=df_ep["episode"], y=df_ep["success"], mode="markers", name="success"))
fig.add_trace(go.Scatter(x=df_ep["episode"], y=df_ep["success_ma_10"], mode="lines", name="success (ma10)"))
fig.update_layout(title="Success per episode", xaxis_title="episode", yaxis_title="success", yaxis=dict(range=[-0.05, 1.05]))
fig.show()

Plot learning signals (losses and Q-values)#

Loss curves are noisy in RL, but you should still see them stabilize as learning progresses.

df_up = pd.DataFrame(update_logs)
if len(df_up) == 0:
    print("No updates ran yet (buffer too small). Try increasing n_episodes or lowering start_learning_after.")
else:
    df_up["update"] = np.arange(len(df_up))
    df_up["critic_loss_ma_200"] = df_up["critic_loss"].rolling(200, min_periods=1).mean()
    df_up["actor_loss_ma_200"] = df_up["actor_loss"].rolling(200, min_periods=1).mean()
    df_up["q_mean_ma_200"] = df_up["q_mean"].rolling(200, min_periods=1).mean()

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df_up["update"], y=df_up["critic_loss"], mode="lines", name="critic_loss"))
    fig.add_trace(go.Scatter(x=df_up["update"], y=df_up["critic_loss_ma_200"], mode="lines", name="critic_loss (ma200)"))
    fig.update_layout(title="Critic loss", xaxis_title="update", yaxis_title="MSE")
    fig.show()

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df_up["update"], y=df_up["actor_loss"], mode="lines", name="actor_loss"))
    fig.add_trace(go.Scatter(x=df_up["update"], y=df_up["actor_loss_ma_200"], mode="lines", name="actor_loss (ma200)"))
    fig.update_layout(title="Actor loss", xaxis_title="update", yaxis_title="-Q(s, pi(s))")
    fig.show()

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df_up["update"], y=df_up["q_mean"], mode="lines", name="Q mean"))
    fig.add_trace(go.Scatter(x=df_up["update"], y=df_up["q_mean_ma_200"], mode="lines", name="Q mean (ma200)"))
    fig.update_layout(title="Average Q-value", xaxis_title="update", yaxis_title="Q")
    fig.show()

Stable-Baselines / Stable-Baselines3 (web-researched equivalents)#

Stable-Baselines wraps the algorithm in an explicit HER model class. Stable-Baselines3 moved HER into the replay buffer: you pass HerReplayBuffer to an off-policy algorithm.

References:

  • HER paper: https://arxiv.org/abs/1707.01495

  • SB3 HER docs: https://github.com/dlr-rm/stable-baselines3/blob/master/docs/modules/her.rst

  • Stable-Baselines HER docs: https://github.com/stable-baselines-team/stable-baselines/blob/master/docs/modules/her.rst

Stable-Baselines3 (SB3)#

From the SB3 docs (docs/modules/her.rst):

from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3
from stable_baselines3.common.envs import BitFlippingEnv

model_class = DQN  # works also with SAC, DDPG and TD3
env = BitFlippingEnv(n_bits=15, continuous=model_class in [DDPG, SAC, TD3], max_steps=15)

model = model_class(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
    verbose=1,
)
model.learn(1000)

Using the environment from this notebook (SB3):

from stable_baselines3 import HerReplayBuffer, SAC

env = PointReachGoalEnv(max_steps=50)
model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
)
model.learn(100_000)

Stable-Baselines (SB)#

From the Stable-Baselines docs (docs/modules/her.rst):

from stable_baselines import HER, DQN, SAC, DDPG, TD3
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

model_class = DQN  # works also with SAC, DDPG and TD3
env = BitFlippingEnv(15, continuous=model_class in [DDPG, SAC, TD3], max_steps=15)

model = HER(
    "MlpPolicy",
    env,
    model_class,
    n_sampled_goal=4,
    goal_selection_strategy="future",
    verbose=1,
)
model.learn(1000)

Notes:

  • both libraries assume the environment exposes compute_reward(achieved_goal, desired_goal, info).

  • for dict observations, SB3 uses "MultiInputPolicy"; a quick compatibility check of the env is sketched below.
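If you want to verify that a custom env such as PointReachGoalEnv follows the expected interface, SB3 ships an env checker (a sketch; it assumes stable-baselines3 is installed):

from stable_baselines3.common.env_checker import check_env

check_env(PointReachGoalEnv(max_steps=50), warn=True)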

Hyperparameters (explained)#

This section explains every hyperparameter used above.

Environment#

  • env_max_steps: episode horizon (more steps = easier exploration, but longer credit assignment).

  • env_distance_threshold: success tolerance (smaller = harder, sparser successes).

  • env_dt: simulation timestep (smaller = finer-grained dynamics, but less distance covered within the fixed horizon).

  • env_goal_range: coordinate bounds for start/goal sampling.

  • env_max_speed: velocity clipping (too low can make goals unreachable within the horizon).

  • env_action_scale: converts normalized actions in \([-1, 1]\) into acceleration.

HER#

  • buffer_capacity_transitions: maximum number of stored transitions; when it is exceeded, the oldest episodes are evicted.

  • n_sampled_goal: number of hindsight goals per real transition; larger means stronger HER signal but more bias toward hindsight goals.

  • goal_selection_strategy:

    • future: relabel with an achieved goal from a later timestep of the same episode (most common).

    • final: always use the final achieved goal.

    • episode: sample an achieved goal uniformly from the whole episode.

DDPG#

  • gamma: discount factor (higher = longer-horizon planning).

  • tau: target-network Polyak averaging rate (smaller = more stable, slower tracking).

  • actor_lr, critic_lr: learning rates.

  • hidden_sizes: MLP hidden layer sizes for both actor and critic.

  • batch_size: SGD batch size from replay buffer.

  • start_learning_after: number of transitions before updates begin (stabilizes early learning).

  • exploration_noise_std: Gaussian action noise scale for exploration.

  • grad_clip_norm: gradient norm clipping to reduce instability.

Training#

  • n_episodes: total training episodes.

  • updates_per_episode: gradient steps after each collected episode.

hp = {
    # env
    "env_max_steps": cfg.env_max_steps,
    "env_distance_threshold": cfg.env_distance_threshold,
    "env_dt": cfg.env_dt,
    "env_goal_range": cfg.env_goal_range,
    "env_max_speed": cfg.env_max_speed,
    "env_action_scale": cfg.env_action_scale,
    # her
    **{f"her.{k}": v for k, v in asdict(cfg.her).items()},
    # ddpg
    **{f"ddpg.{k}": v for k, v in asdict(cfg.ddpg).items()},
    # train
    "train.n_episodes": cfg.n_episodes,
    "train.updates_per_episode": cfg.updates_per_episode,
}
pd.DataFrame({"hyperparameter": list(hp.keys()), "value": list(hp.values())})
hyperparameter value
0 env_max_steps 50
1 env_distance_threshold 0.05
2 env_dt 0.05
3 env_goal_range 1.0
4 env_max_speed 1.0
5 env_action_scale 10.0
6 her.buffer_capacity_transitions 100000
7 her.n_sampled_goal 4
8 her.goal_selection_strategy future
9 ddpg.gamma 0.98
10 ddpg.tau 0.005
11 ddpg.actor_lr 0.001
12 ddpg.critic_lr 0.001
13 ddpg.hidden_sizes (256, 256)
14 ddpg.batch_size 256
15 ddpg.start_learning_after 2000
16 ddpg.exploration_noise_std 0.2
17 ddpg.grad_clip_norm 1.0
18 train.n_episodes 150
19 train.updates_per_episode 50