Trust Region Policy Optimization: Practical Algorithm
In the last post on TRPO, we proved the Monotonic Improvement Guarantee Theorem (MIGT) and derived a lower bound to the true objective \(\eta(\tilde{\pi})\):
\begin{equation} \label{MIGT} \eta(\tilde{\pi}) \geq L_{\pi}(\tilde{\pi}) - \frac{4 \epsilon \alpha^2 \gamma}{(1-\gamma)^2} \end{equation}
Where \(\epsilon = \max_{s, a} \| A_{\pi}(s,a) \|\), and \(\alpha = \max_s D_{TV}(\pi(.\|s) \| \tilde{\pi}(.\|s))\). This result allows us to design an algorithm that will guarantee monotonic policy improvement for every iteration.
However, as we’ll highlight in the next section, this lower bound, while insightful, isn’t amenable to a practical algorithm. In this post we’ll describe some approximations the authors make to the original result, and derive a practical RL algorithm. We’ll also dive into some of the mathematical and computational details that were glossed over in the original paper, but which I think are important for an overall understanding of the TRPO algorithm, its strengths, and its limitations.
Heuristic Approximations for the Surrogate Objective
The first modification to the MIGT is to replace the total variation divergence \(D_{TV}\) in \ref{MIGT} with \(D_{KL}\), the KL divergence. The relationship between the two divergences (which follows from Pinsker’s Inequality) gives us the following:
\[D_{TV}(\pi(.\|s) \| \tilde{\pi}(.\|s))^2 \leq D_{KL}(\pi(.\|s) \| \tilde{\pi}(.\|s))\]Since the above holds for every state \(s\), it must also be the case that:
\[\max_s D_{TV}(\pi(.\|s) \| \tilde{\pi}(.\|s))^2 \leq \max_s D_{KL}(\pi(.\|s) \| \tilde{\pi}(.\|s))\]And this gives us the MIGT in terms of \(D_{KL}\):
\begin{equation} \label{MIGT2} \eta(\tilde{\pi}) \geq L_{\pi}(\tilde{\pi}) - C \max_s D_{KL}(\pi(.|s) \| \tilde{\pi}(.|s)) \end{equation}
Where \(C = \frac{4 \epsilon \gamma}{(1-\gamma)^2}\).
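As a quick sanity check of the divergence relationship used above, here is a small numeric example (my own, not from the paper) comparing the two quantities for a pair of arbitrary categorical distributions:

```python
import numpy as np

# Two arbitrary categorical distributions over 4 actions (illustrative values).
p = np.array([0.4, 0.3, 0.2, 0.1])
q = np.array([0.25, 0.25, 0.25, 0.25])

# Total variation distance: half the L1 distance between the distributions.
d_tv = 0.5 * np.abs(p - q).sum()

# KL divergence D_KL(p || q), in nats.
d_kl = np.sum(p * np.log(p / q))

print(f"D_TV^2 = {d_tv**2:.4f}, D_KL = {d_kl:.4f}")
assert d_tv**2 <= d_kl  # the bound D_TV^2 <= D_KL holds here
```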
The KL divergence has some nice properties and is computationally more convenient than the total variation divergence. It is also closely related to the Fisher Information Matrix (FIM) and Natural Gradients, both of which are established ideas in RL and will prove useful, as we will see in subsequent sections.
The issue with the above lower bound is that the penalty coefficient \(C\) results in very small step sizes, due to the \(\epsilon = \max_{s, a} \| A_{\pi}(s,a) \|\) term, which comes from a very loose theoretical bound. In practice, such a harsh penalty is often not warranted, and we can often get away with larger step sizes. A more robust way to do this is to impose a constraint on the KL divergence between the new policy and the old policy, a trust region constraint:
\[\text{maximize}_{\theta} \; L_{\theta_{\text{old}}} (\theta) \\ \text{s.t.} \; \max_s D_{KL}(\theta_{\text{old}} \| \theta) \leq \delta\]This is better, but the above optimization problem is still intractable to solve, due to the constraint imposed over every state \(s\). A final approximation the authors make is to replace the \(\max_s D_{KL}\) term with one that is averaged over states instead:
\[\bar{D^{\rho}}_{KL}(\theta_1, \theta_2) = \mathbb{E}_{s \sim \rho} [ D_{KL}(\pi_{\theta_1}(.\|s) \| \pi_{\theta_2}(.\|s)) ]\]Giving the optimization problem:
\begin{equation} \label{MIGT3} \text{maximize}_{\theta} \; L_{\theta_{\text{old}}} (\theta) \quad \text{s.t.} \; \bar{D^{\rho}}_{KL}(\theta_{\text{old}}, \theta) \leq \delta \end{equation}
The authors note that this problem is similar to ones proposed in prior work, most notably Natural Policy Gradients, whose results will prove useful in the discussion in the following sections.
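To make the objective and constraint in \ref{MIGT3} concrete, here is a minimal sketch of how both quantities might be estimated from sampled data for a discrete-action policy. This is my own illustration (the function and variable names are not from the paper), assuming `policy` and `old_policy` map a batch of states to action logits:

```python
import torch
from torch.distributions import Categorical, kl_divergence

def surrogate_and_kl(policy, old_policy, states, actions, advantages):
    """Monte Carlo estimates of L_{theta_old}(theta) and the averaged KL constraint.

    `old_policy` is a frozen copy of the policy used to collect the data, and
    `advantages` are precomputed advantage estimates for the sampled (s, a) pairs.
    """
    dist = Categorical(logits=policy(states))
    with torch.no_grad():
        old_dist = Categorical(logits=old_policy(states))

    # Importance-weighted surrogate: E[ pi_theta(a|s) / pi_theta_old(a|s) * A(s, a) ]
    ratio = torch.exp(dist.log_prob(actions) - old_dist.log_prob(actions))
    surrogate = (ratio * advantages).mean()

    # Average KL over visited states: \bar{D}^rho_KL(theta_old, theta)
    mean_kl = kl_divergence(old_dist, dist).mean()
    return surrogate, mean_kl
```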
How to Solve the Trust-Region Constrained Optimization Problem?
The final TRPO algorithm, as outlined in the paper, proceeds as follows (copy-pasted verbatim from the paper):
- Use the single path or vine procedures to collect a set of state-action pairs along with Monte Carlo estimates of their \(Q\)-values
- By averaging over samples, construct the estimated objective and constraint in Equation (\ref{MIGT3}).
- Approximately solve this constrained optimization problem to update the policy’s parameter vector \(\theta\). We use the conjugate gradient algorithm followed by a line search, which is altogether only slightly more expensive than computing the gradient itself.
The paper largely glosses over point 3 (even in the Appendix), which can be disorienting to readers new to the topic, as solving the constrained problem in an efficient manner with sample-based minibatch methods is far from trivial.
Therefore, it’s worth walking through it step-by-step, as doing so gives some computational insights as to what exactly TRPO is doing.
A Quadratic Approximation and Fisher Information
The paper mentions constructing the Fisher Information Matrix (FIM), which can be defined as the sample-averaged Hessian of the KL divergence term w.r.t the current policy parameters \(\theta\):
\[A_{ij} = \frac{1}{N} \sum_{n=1}^N \frac{\partial^2}{\partial \theta_i \partial \theta_j} D_{KL}(\theta_{\text{old}} \| \theta)\]To see where this comes from, we’ll need to construct a quadratic approximation of \ref{MIGT3}. To do so, we’ll follow the steps outlined in the Natural Policy Gradients paper.
Since we’re performing gradient ascent on the surrogate objective, we look for a small \(\Delta\theta\) that maximizes \(L\) subject to the KL divergence constraint:
\[\text{maximize}_{\Delta\theta} \; L_{\theta} (\theta + \Delta\theta) \\ \text{s.t.} \; \bar{D^{\rho}}_{KL}(\theta, \theta + \Delta\theta) \leq \delta\]For small \(\Delta\theta\), we can approximate the KL divergence term by its 2nd order Taylor series:
\[D_{KL}(\theta, \theta + \Delta\theta) \\ \approx D_{KL}(\theta, \theta) + \nabla_{\theta'} D_{KL}(\theta, \theta') \Delta\theta + \frac{1}{2} \Delta\theta^T \nabla_{\theta'}^2 D_{KL}(\theta, \theta') \Delta\theta \\ = 0 + \mathbb{E}_a [\nabla_{\theta'} \log \pi_{\theta}] \Delta\theta - \mathbb{E}_a [\nabla_{\theta'} \log \pi_{\theta'}] \Delta\theta \\ + \frac{1}{2} \Delta\theta^T \mathbb{E}_a [\nabla_{\theta'}^2 \log \pi_{\theta}] \Delta\theta - \frac{1}{2} \Delta\theta^T \mathbb{E}_a [\nabla_{\theta'}^2 \log \pi_{\theta'}] \Delta\theta \\ = - \mathbb{E}_a [\nabla_{\theta'} \log \pi_{\theta'}] \Delta\theta - \frac{1}{2} \Delta\theta^T \mathbb{E}_a [\nabla_{\theta'}^2 \log \pi_{\theta'}] \Delta\theta \\ = - \mathbb{E}_a [\nabla_{\theta'} \log \pi_{\theta'}] \Delta\theta + \frac{1}{2} \Delta\theta^T \textbf{F} \Delta\theta \\\]Where the expectations \(\mathbb{E}_a\) are taken w.r.t the policy distribution \(a \sim \pi_{\theta}\), all derivatives are evaluated at \(\theta' = \theta\), and we have defined \(\textbf{F} = -\mathbb{E}_a [\nabla_{\theta'}^2 \log \pi_{\theta'}]\). In the second step we evaluated the term \(D_{KL}(\theta, \theta) = 0\) and expanded the derivatives of \(D_{KL}(\theta, \theta')\), and in the third step we dropped the derivative terms that do not depend on \(\theta'\), as they are zero.
The first order term vanishes because the expected score is zero: by linearity, \(\mathbb{E}_a [\nabla \log \pi_{\theta}] = \mathbb{E}_a [\frac{\nabla \pi_{\theta}}{\pi_{\theta}}] = \sum_a \nabla \pi_{\theta}(a) = \nabla \sum_a \pi_{\theta}(a) = \nabla 1 = 0\)
This leaves us with the 2nd order term. We can now formulate the optimization problem \ref{MIGT3} as a Lagrangian: linearizing the objective around \(\theta\) and using the quadratic approximation above for the constraint, we get the following Lagrangian to maximize:
\begin{equation} \label{Lagrangian} L_{\theta} (\theta) + \nabla L_{\theta}(\theta) \Delta \theta - \frac{1}{2} \lambda \Delta\theta^T \textbf{F} \Delta\theta \end{equation}
Setting the gradient of the above with respect to \(\Delta \theta\) to zero and solving gives us:
\[\Delta \theta = \frac{1}{\lambda} \textbf{F}^{-1} \nabla L_{\theta}(\theta)\]The resulting expression is interesting to study. We can conclude that the direction of \(\Delta\theta\) does not depend on the constraint value \(\delta\); only the step size \(\frac{1}{\lambda}\) does. The direction depends only on the gradient term \(\nabla L_{\theta}(\theta)\) and the negative expected Hessian of the log of the policy: \(\textbf{F} = -\mathbb{E}_a [\nabla_{\theta}^2 \log \pi_{\theta}]\).
Let’s look at the term \(\textbf{F}\) and expand it further:
\[\textbf{F} = -\mathbb{E}_a [\nabla_{\theta}^2 \log \pi_{\theta}] \\ = -\mathbb{E}_a \left[ \frac{\nabla_{\theta}^2\pi_{\theta}(a)}{\pi_{\theta}(a)} - \left( \frac{\nabla_{\theta} \pi_{\theta}(a)}{\pi_{\theta}(a)} \right)^2 \right] \\ = -\mathbb{E}_a \left[ \frac{\nabla_{\theta}^2\pi_{\theta}(a)}{\pi_{\theta}(a)} \right] + \mathbb{E}_a \left[ \left( \nabla_{\theta} \log \pi_{\theta}(a) \right)^2 \right] \\ = \mathbb{E}_a \left[ \left( \nabla_{\theta} \log \pi_{\theta}(a) \right)^2 \right] \\\]The first term vanishes, again because summation and differentiation can be interchanged: \(\mathbb{E}_a [\frac{\nabla^2 \pi_{\theta}}{\pi_{\theta}}] = \sum_a \nabla^2 \pi_{\theta}(a) = \nabla^2 \sum_a \pi_{\theta}(a) = \nabla^2 1 = 0\)
What we end up with is what is known as the Fisher Information Matrix:
\[\textbf{F} = \mathbb{E}_a \left[ \left( \nabla_{\theta} \log \pi_{\theta}(a) \right)^2 \right]\]Here \((\cdot)^2\) is shorthand for the outer product \(\nabla_{\theta} \log \pi_{\theta}(a) \, \nabla_{\theta} \log \pi_{\theta}(a)^T\). The Fisher Information essentially characterises the curvature of the log-likelihood function around a point \(\theta\). To illustrate this, let’s consider a one-dimensional parametrized policy, and consider how the log-likelihood function \(\text{ll}(\theta')\) varies as we move away from \(\theta\):
\[\text{ll}(\theta') = \mathbb{E}_{a \sim \pi_{\theta}(a)} [ \log \pi_{\theta'}(a) ]\]The negative 2nd derivative of this function, evaluated at \(\theta' = \theta\), is exactly the Fisher Information, as we showed above. It characterizes the curvature of the function around its maximum at \(\theta\). For a log-likelihood function that has a sharp maximum around \(\theta\), the Fisher Information is a very large scalar, while for a log-likelihood function with a flat maximum at \(\theta\), the Fisher Information is smaller in magnitude.
Figure illustrating a log-likelihood function with a sharp maximum and one with a flat maximum. Notice how, for roughly the same (small) decrease in the log-likelihood function, the step size for the distribution with the large FIM value is much smaller than that for the distribution with the smaller FIM value
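To make the two expressions for \(\textbf{F}\) concrete, here is a small numeric check (a toy example of my own, not from the paper) for a single-state softmax policy over four actions. It builds \(\textbf{F}\) from the score outer products and verifies the quadratic approximation of the KL divergence derived above:

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

rng = np.random.default_rng(0)
theta = rng.normal(size=4)          # a 4-action softmax policy (toy example)
pi = softmax(theta)

# Fisher Information from the score outer-product form:
# F = E_a[ grad log pi(a) grad log pi(a)^T ], with grad_theta log pi(a) = e_a - pi.
scores = np.eye(4) - pi             # row a is the score vector for action a
F = sum(pi[a] * np.outer(scores[a], scores[a]) for a in range(4))

# For a softmax policy this has the known closed form diag(pi) - pi pi^T.
assert np.allclose(F, np.diag(pi) - np.outer(pi, pi))

# Check the quadratic approximation D_KL(theta, theta + dtheta) ~ 1/2 dtheta^T F dtheta.
dtheta = 0.01 * rng.normal(size=4)
pi_new = softmax(theta + dtheta)
kl_exact = np.sum(pi * np.log(pi / pi_new))
kl_quad = 0.5 * dtheta @ F @ dtheta
print(kl_exact, kl_quad)            # the two closely agree for small dtheta
```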
The natural gradient scales the original gradient \(\nabla L_{\theta}(\theta)\) by the inverse of the Fisher Information, \(\textbf{F}^{-1}\). This has the effect of shrinking the gradient step in directions where the policy distribution changes very quickly as we change \(\theta\). This makes sense: if we change the policy too much around the current point / state, we may adversely affect the performance of the policy in other states.
Another way to think about this scaling: instead of computing a fixed-size step in the parameter space \(\theta\) that maximizes the objective function, we are computing a fixed-size step in the space of policy distributions \(\pi\), with the size of the step measured in terms of the KL divergence and limited to \(\delta\).
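To see why a fixed-size step in parameter space is not the same as a fixed-size step in distribution space, consider the following toy example (my own) with a two-action softmax policy: the same Euclidean step \(\Delta\theta\) produces very different KL divergences depending on where \(\theta\) currently sits.

```python
import numpy as np

def softmax(theta):
    z = np.exp(theta - theta.max())
    return z / z.sum()

def kl(p, q):
    return np.sum(p * np.log(p / q))

dtheta = np.array([0.1, 0.0])            # the same small step in parameter space

for theta in (np.array([0.0, 0.0]),      # near-uniform policy
              np.array([4.0, 0.0])):     # near-deterministic policy
    p, q = softmax(theta), softmax(theta + dtheta)
    print(f"theta={theta}, KL = {kl(p, q):.6f}")

# The near-uniform policy moves over an order of magnitude further in KL for the
# same parameter step, which is why TRPO constrains the KL divergence rather
# than the Euclidean norm of the parameter update.
```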
Efficient Optimization
Notice that in order to compute the search direction \(\Delta \theta\), it seems that we may need to compute (and invert) the FIM \(\textbf{F}\). This operation typically scales like \(O(M^3)\), where \(M\) is the total number of model parameters (ie the length of the parameter vector \(\theta\)). For neural nets with millions of parameters, this is usually not computationally tractable.
Fortunately, we can compute this quantity without explicitly inverting the matrix, using the conjugate gradient (CG) method. (The details of the CG algorithm are outside the scope of this post; a high-level description will suffice here.) CG allows us to solve equations of the form \(\textbf{Fx} = \textbf{g}\) without inverting the matrix \(\textbf{F}\), or even forming the full matrix \(\textbf{F}\) (which is itself \(O(M^2)\)). We only need access to a function that computes the matrix-vector product \(y \rightarrow \textbf{F} y\). Fortunately for us, this can also be done efficiently, with computational cost far less than \(O(M^2)\).
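The CG routine itself is short. Below is a minimal, textbook conjugate gradient solver (my own sketch, not the authors’ code) that only ever calls a matrix-vector product `fvp`, i.e. \(y \rightarrow \textbf{F}y\), which we discuss next; public TRPO implementations typically run it for only around 10 iterations.

```python
import torch

def conjugate_gradient(fvp, g, iters=10, tol=1e-10):
    """Approximately solve F x = g using only Fisher-vector products fvp(y) -> F y."""
    x = torch.zeros_like(g)
    r = g.clone()               # residual r = g - F x (x = 0 initially)
    p = g.clone()               # current search direction
    r_dot_r = r.dot(r)
    for _ in range(iters):
        Fp = fvp(p)
        alpha = r_dot_r / p.dot(Fp)
        x += alpha * p
        r -= alpha * Fp
        new_r_dot_r = r.dot(r)
        if new_r_dot_r < tol:
            break
        p = r + (new_r_dot_r / r_dot_r) * p
        r_dot_r = new_r_dot_r
    return x
```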
In order to do this, we need an efficient way to compute the product of the FIM with an arbitrary vector, \(\textbf{F} y\). One potential way to do this is to use the outer-product form of the Fisher Information Matrix directly:
\[(\textbf{F}y)_i = \left( \mathbb{E}_a \left[ \left( \nabla_{\theta} \log \pi_{\theta}(a) \right)^2 \right] y \right)_i = \sum_j \mathbb{E}_{a \sim \pi_{\theta_{\text{old}}}} \left[ \frac{\partial}{\partial \theta_i} \log \pi_{\theta}(a) \frac{\partial}{\partial \theta_j} \log \pi_{\theta}(a) \right] y_j\]The expectation, taken w.r.t the old policy \(\pi_{\theta_{\text{old}}}\), can be computed by storing the policy gradients for the trajectories sampled using this old policy.
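As a concrete sketch of this stored-gradients approach (my own illustration), assume `logp_grads` is an \(N \times M\) matrix whose \(n\)-th row is the flattened gradient \(\nabla_{\theta} \log \pi_{\theta}(a_n|s_n)\):

```python
import torch

def fisher_vector_product_from_samples(logp_grads, y):
    """F y using the sample estimate F = (1/N) sum_n g_n g_n^T,
    where g_n = grad_theta log pi_theta(a_n | s_n).

    Since F y = (1/N) sum_n g_n (g_n . y), we never form the M x M matrix F,
    but we do have to keep the full (N, M) matrix of per-sample gradients.
    """
    coeffs = logp_grads @ y                          # (N,) vector of g_n . y
    return logp_grads.t() @ coeffs / logp_grads.shape[0]
```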
Written out explicitly, this sample-based estimate is:
\[(\textbf{F}y)_i = \frac{1}{N} \sum_j \sum_n \left[ \frac{\partial}{\partial \theta_i} \log \pi_{\theta}(a_n \| s_n) \frac{\partial}{\partial \theta_j} \log \pi_{\theta}(a_n \| s_n) \right] y_j\]This can be computationally costly, however, since it requires storing the full gradient vector for every sample. The paper proceeds to suggest an alternative formulation that computes this expectation analytically. First, note that the KL divergence can be written:
\[D_{KL}(\pi_{\theta_{\text{old}}}( \cdot | s ), \pi_{\theta}( \cdot | s )) = \text{kl}(\mu_{\theta}(s), \mu_{\text{old}}(s))\]The parametrized policy maps the input state \(s\) to a distribution parameter vector \(\mu_{\theta}(s)\), which parametrizes the output distribution \(\pi_{\theta}( \cdot | s )\). For a categorical distribution this could be the class probabilities; for a Gaussian output distribution, the means and variances. \(\text{kl}\) is the KL divergence between the distributions corresponding to the two distribution parameter vectors.
Differentiating twice w.r.t \(\theta\), we get:
\[\frac{\partial \mu_{\theta}(s)_a}{\partial \theta_i} \frac{\partial \mu_{\theta}(s)_b}{\partial \theta_j} \text{kl}''_{ab}(\mu_{\theta}(s), \mu_{\text{old}}(s)) + \frac{\partial^2 \mu_{\theta}(s)_a}{\partial \theta_i \partial \theta_j} \text{kl}'_{a}(\mu_{\theta}(s), \mu_{\text{old}}(s))\]Where the primes (\('\)) indicate differentiation w.r.t the first argument, and there is an implied summation over the indices \(a, b\). The second term vanishes because the first derivative of the KL divergence is zero at its minimum, i.e. when evaluated at \(\theta = \theta_{\text{old}}\), where \(\mu_{\theta}(s) = \mu_{\text{old}}(s)\); this leaves us with just the first term. We denote by \(J = \frac{\partial \mu_{\theta}(s)_a}{\partial \theta_i}\) the Jacobian of the distribution parameters w.r.t the policy parameters, and by \(M = \text{kl}''_{ab}(\mu_{\theta}(s), \mu_{\text{old}}(s))\) the Fisher Information Matrix in terms of the distribution parameters \(\mu\), which has a simple analytic form for many distributions of interest.
The FIM-vector product can thus be written:
\[\textbf{F}y = J^T M J y\]This is computationally cheap to evaluate (it scales linearly with the number of policy parameters), and does not require storing the per-sample policy gradients.
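In practice, many publicly available TRPO implementations realize this product not by forming \(J\) and \(M\) explicitly, but as a Hessian-vector product of the sample-averaged KL divergence via double backpropagation, which coincides with \(J^T M J y\) when evaluated at the current parameters \(\theta_{\text{old}}\). A minimal PyTorch sketch (my own; the small damping term is a common implementation detail for numerical stability, not part of the derivation):

```python
import torch

def fisher_vector_product(mean_kl_fn, params, y, damping=1e-2):
    """Compute F y as a Hessian-vector product of the averaged KL divergence.

    `mean_kl_fn()` should recompute the averaged KL(theta_old, theta) with the
    old policy treated as a constant (detached), so that its Hessian w.r.t.
    `params` is the Fisher matrix. `params` is the list of policy parameters
    and `y` a flat vector of the same total dimension.
    """
    mean_kl = mean_kl_fn()
    grads = torch.autograd.grad(mean_kl, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    grad_dot_y = (flat_grad * y).sum()
    hvp = torch.autograd.grad(grad_dot_y, params)
    flat_hvp = torch.cat([h.reshape(-1) for h in hvp])
    return flat_hvp + damping * y
```

Passing `lambda y: fisher_vector_product(mean_kl_fn, params, y)` to the conjugate gradient routine sketched earlier yields the search direction \(\textbf{F}^{-1}\textbf{g}\) without ever forming \(\textbf{F}\).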
Having computed the search direction \(\Delta \theta = \textbf{F}^{-1}\textbf{g}\), where we have denoted \(\textbf{g} = \nabla L_{\theta}(\theta)\), we need to compute the step size. We can start by computing the maximal step length \(\beta\) such that \(\theta + \beta \Delta \theta\) satisfies \ref{MIGT3} at equality (recall that the direction of the step does not depend on \(\delta\), so it makes no difference whether the constraint is an equality or an inequality):
\begin{equation} \label{MIGT4} \text{maximize}_{\theta} \;L_{\theta_{\text{old}}} (\theta) \quad \text{s.t.} \; \bar{D^{\rho}}_{KL}(\theta_{\text{old}}, \theta) = \delta \end{equation}
From this we obtain:
\[\delta = \bar{D^{\rho}}_{KL} \approx \frac{1}{2} (\beta \Delta \theta)^T \textbf{F} (\beta \Delta \theta) \\ \beta = \sqrt{ \frac{2\delta}{ (\Delta \theta)^T \textbf{F} (\Delta \theta)} }\]A line search is then performed on the objective \(L_{\theta_{\text{old}}}(\theta) - \chi [\bar{D^{\rho}}_{KL}(\theta_{\text{old}}, \theta) \leq \delta]\), where \(\chi[...]\) equals zero when its argument is true, and \(+\infty\) when it is false. The search shrinks \(\beta\) exponentially until the objective improves. The authors note that if this is not done, the algorithm will occasionally compute large steps that cause a large degradation of performance.
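Putting these last two pieces together, here is a rough sketch (my own; `eval_fn` is a hypothetical helper that returns the sampled surrogate and averaged KL at a given flat parameter vector) of the maximal step length and the backtracking line search described above:

```python
import torch

def max_step_length(step_dir, fvp, delta):
    """beta = sqrt(2 delta / (dtheta^T F dtheta)), from the equation above."""
    return torch.sqrt(2 * delta / step_dir.dot(fvp(step_dir)))

def backtracking_line_search(theta_old, step_dir, beta, eval_fn, delta,
                             max_backtracks=10):
    """Shrink beta exponentially until the surrogate improves while the
    exact averaged KL constraint is still satisfied."""
    surr_old, _ = eval_fn(theta_old)
    for k in range(max_backtracks):
        theta_new = theta_old + (beta * 0.5 ** k) * step_dir
        surr, mean_kl = eval_fn(theta_new)
        if surr > surr_old and mean_kl <= delta:
            return theta_new
    return theta_old          # no acceptable step found; keep the old parameters
```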
Conclusion
And that’s the TRPO algorithm in all its gory detail. The main takeaways here are:
- The lower bound in the original MIGT theorem is computationally intractable to optimize directly, and its conservative penalty coefficient leads to very small steps and hence very slow convergence
- Approximation heuristics, such as imposing a KL-divergence constraint and taking the average instead of the max of the KL divergence, give the more tractable optimization problem \ref{MIGT3}
- To solve the constrained optimization problem \ref{MIGT3}, we use a linear approximation of the objective and a quadratic approximation of the constraint; the resulting update requires computing the product of the inverse Fisher Information Matrix of the policy \(\pi_{\theta}\) with the gradient of the objective \(\nabla L_{\theta}(\theta)\)
- The CG algorithm and an analytic form of the FIM make computing this product much more computationally efficient.