
The Need for Relative Optimizers | Hypothesis on Muon

  • Writer: Ethan Smith
  • Mar 18
  • 11 min read

Updated: Mar 30

 
 

Presently, most optimizers used in deep learning do not explicitly scale their updates to the expected range of magnitudes a given parameter will take on, though this might be a desirable trait.


A nudge of size 0.001 might have a stronger effect on a parameter that is expected to sit around 0.004, in a matrix of similarly sized values, than on a parameter that sits around ~10 (not common for most parameters, but some norm params get large!). Proportionally, 0.004 moving to 0.003 is a much larger change than 10 moving to 9.999. In other words, as with most things, all is relative.


To observe the relativity, you can try perturbing an existing model's weights in two different ways: multiplication and addition, and see which hurts the model more.
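A minimal sketch of the two perturbations, assuming NumPy. The `weights` array is a hypothetical stand-in for one of the model's matrices and `eps` is an illustrative noise scale; in a real test you would load the perturbed weights back into the model and compare evaluation loss.

```python
import numpy as np

def perturb_additive(weights, eps, rng):
    """Add noise of the same absolute size to every element."""
    return weights + eps * rng.standard_normal(weights.shape)

def perturb_multiplicative(weights, eps, rng):
    """Scale every element by (1 + noise), so the change is proportional to its size."""
    return weights * (1.0 + eps * rng.standard_normal(weights.shape))

# Hypothetical stand-in for a weight matrix: one row of small values, one of large.
weights = np.array([[0.004, -0.003, 0.005],
                    [10.0, -9.0, 11.0]])
eps = 0.001

for name, fn in [("additive", perturb_additive),
                 ("multiplicative", perturb_multiplicative)]:
    # Same seed for both, so the underlying noise draw is identical.
    perturbed = fn(weights, eps, np.random.default_rng(0))
    rel_change = np.abs(perturbed - weights) / np.abs(weights)
    print(name, "mean relative change per row:", rel_change.mean(axis=1))
```

With additive noise the relative change depends heavily on the size of the weight, while with multiplicative noise it depends only on the noise draw.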



Preconditioning doesn't exactly help

An irony is that optimizers with preconditioning like ADAM would seem to exacerbate the issue of updates neglecting the range of values a parameter might take on. The size of the incoming gradient may be informative of how large a shift a parameter needs for a meaningfully sized adjustment, and larger parameters often yield larger gradients as well, yet we normalize this information away. Assuming gradients arrive with constant norm on average, ADAM's second moment normalization encourages the update for a given dimension to have an expected absolute value of about 1. In fact, in the limiting case beta1=beta2=0, the update is exactly the sign of the gradient, and Signum is the momentum version of the same idea, where:

update = sign(momentum). 

The update is just +1/-1 scaled by the learning rate.
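A small check of this sign-like behaviour, assuming NumPy and bias-corrected ADAM with epsilon set to 0. The helper `adam_direction` and the gradient values are hypothetical; the sketch only shows that in the limiting case beta1 = beta2 = 0 the direction is exactly the sign of the latest gradient.

```python
import numpy as np

def adam_direction(grads, beta1, beta2, eps=0.0):
    """Return the bias-corrected ADAM direction m_hat / sqrt(v_hat) after seeing `grads`."""
    m = np.zeros_like(grads[0])
    v = np.zeros_like(grads[0])
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # first moment (momentum)
        v = beta2 * v + (1 - beta2) * g**2     # uncentered second moment
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    return m_hat / (np.sqrt(v_hat) + eps)

# Hypothetical gradients spanning several orders of magnitude.
grads = [np.array([0.3, -2.0, 1e-4]), np.array([-0.1, 5.0, -3e-3])]

print(adam_direction(grads, beta1=0.0, beta2=0.0))    # sign of the latest gradient
print(adam_direction(grads, beta1=0.9, beta2=0.999))  # same order of magnitude per dimension
```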


However, in common scenarios, beta1=0.9 and beta2>=0.98 or so. The lag of the second moment (variance) adds some nuance to what is happening, though the goal still seems to be to give all dimensions a similarly sized update, regardless of their parameter size or expected gradient size. The running variance of the gradients for a given dimension tracks the average magnitude they take on, suggesting a typical range for gradients and parameter sizes to play in. Because we normalize by it, the size of an update is measured relative to this range, i.e. our updates could be measured in standard deviations. Effectively, post normalization, if an update for a given dimension has a value of -1/+1, then it matches the expected magnitude exactly. If it is larger, say 3, then we have an outlier at 3 standard deviations. It could also be smaller.
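To make the "standard deviations" reading concrete, here is a sketch for a single dimension, assuming NumPy. The typical gradient magnitude of 0.01 is a made-up value, and for simplicity the raw gradient is normalized rather than the momentum.

```python
import numpy as np

beta2 = 0.99
typical = 0.01          # hypothetical typical gradient magnitude for this dimension
v, steps = 0.0, 500

# Feed gradients of constant magnitude and alternating sign.
for t in range(1, steps + 1):
    g = typical * (-1) ** t
    v = beta2 * v + (1 - beta2) * g**2
v_hat = v / (1 - beta2**steps)   # bias-corrected second moment, close to typical**2

for g in (0.01, 0.03, 0.001):
    # Roughly 1, 3 and 0.1 in units of the running scale.
    print(g, "->", g / np.sqrt(v_hat))
```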


Preconditioning is not an arbitrary design choice. Preconditioned gradient descent has a wealth of literature behind it. The iconic diagram of descent on an elongated 2d bowl illuminates how preconditioning can help us avoid a zig-zag style descent and take longer, concerted steps toward the minimum.
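The zig-zag can be reproduced on a toy quadratic. This is a sketch under stated assumptions: the bowl, its curvatures `h`, the starting point and both learning rates are hypothetical, and the preconditioner is the exact diagonal Hessian, which real optimizers only estimate.

```python
import numpy as np

# Hypothetical elongated bowl: f(x) = 0.5 * (h[0] * x0^2 + h[1] * x1^2)
h = np.array([1.0, 50.0])          # diagonal of the Hessian (curvatures)

def grad(x):
    return h * x

def descend(x0, lr, steps, precond=None):
    """Gradient descent, optionally dividing the gradient by a diagonal preconditioner."""
    x = np.array(x0, dtype=float)
    path = [x.copy()]
    for _ in range(steps):
        g = grad(x)
        if precond is not None:
            g = g / precond
        x = x - lr * g
        path.append(x.copy())
    return np.array(path)

# The plain learning rate sits just under the stability limit 2 / 50 of the steep axis.
plain_path = descend([10.0, 1.0], lr=0.039, steps=20)
precond_path = descend([10.0, 1.0], lr=0.5, steps=20, precond=h)

print("plain x1 signs:", np.sign(plain_path[:6, 1]))   # alternates: the zig-zag
print("plain final:", plain_path[-1])                   # flat axis has barely moved
print("preconditioned final:", precond_path[-1])        # both axes shrink together
```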

Though it may be possible to have the best of both worlds: smoother preconditioned descent, plus a scaling that lets us step quickly where parameters are meant to have large values and more slowly where parameters are meant to have small values.


A counterargument is that we might not be “ready” to take such a large step. This isn't easy to visualize in the 2d convex bowl case, but it might be more relevant in more turbulent high-dimensional loss landscapes with lots of saddle points.



Empirical Weight Distributions

Something that has often frustrated me is that we can see how the distribution of weights differs between initialization and a pretrained model, particularly in the heavy tail of large outliers. With this in mind, it feels like we should be able to define better initializations that live closer to where solutions end up. Granted, this is easier said than done.


We know from past trained models that some weights will end up quite large, but we can't really know a priori which weights these will be. So how can we update our beliefs here?


Let's look at the distribution of maximum found weights at Xavier initialization vs the trained "meta-llama/Llama-3.2-1B-Instruct"


The orange bins showcase the distribution of found absolute max value for a given parameter matrix at initialization, while blue shows it in the pretrained model. The high density at 1.0 for orange is all of the norm scale parameters that are initialized to 1.0.
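The initialization side of such a comparison could be computed along these lines. This is only a sketch: it assumes NumPy and the uniform variant of Xavier initialization, and the layer names and shapes are hypothetical placeholders; a real comparison would walk the pretrained model's own parameters and take the same statistic there.

```python
import numpy as np

def xavier_uniform(fan_out, fan_in, rng):
    """Xavier/Glorot uniform init: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    bound = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-bound, bound, size=(fan_out, fan_in))

rng = np.random.default_rng(0)
# Hypothetical layer names and shapes, kept small so the example runs quickly.
shapes = {"attn.q_proj": (512, 512), "mlp.up_proj": (2048, 512)}

for name, (fan_out, fan_in) in shapes.items():
    w = xavier_uniform(fan_out, fan_in, rng)
    # One number per matrix: the statistic that would be binned in the histogram.
    print(name, "abs max at init:", np.abs(w).max())
```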


What we can see here is that the largest weights can end up more than 10x larger than the range they were initialized in! These are considerably large values in general, with some nearing 3.0. Even in the best case, the learning rates we use need a large number of steps to reach these values.
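As a rough back-of-the-envelope, assuming a sign-like update of magnitude 1 that always points the right way, a constant learning rate, and made-up start and target values:

```python
lr = 3e-4                    # hypothetical learning rate
start, target = 0.05, 3.0    # hypothetical init-scale value and a large trained value
max_step = lr * 1.0          # best case: every update has magnitude 1 toward the target

steps_needed = (target - start) / max_step
print(steps_needed)          # on the order of 10,000 steps for a single weight
```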

Here are some additional plots of means, stds, medians, and min/max ranges.

A fairly large number of matrices still have statistics near initialization, but a significant number do not. Particularly telling is the number of weight matrices with medians around ~0.5, revealing just how many dimensions have taken on large magnitudes.

Notably, the tick way over by 3.0 comes from the final norm layer which has unusually large weights compared to the rest of the norm layer parameters.


It would seem that the chosen global learning rate caters most to the parameters that need the smallest nudges, to ensure we have stable optimization. Stable optimization asks that we cater to the weakest link (smallest values) while possibly making optimization slower than necessary for the others. One option is to vary the learning rate per parameter, which is sort of what methods like MuP aim to handle.


Though we might like to be able to adapt optimization per dimension.


Knowing that a weight configuration or "fingerprint" like this is common, I think there are a few things we could try to lean into it.

  1. Optimizers that scale their update with respect to the expected range for an element of a parameter. (this post)

  2. Heavy tailed weight initializations.



Relative Optimizers

A few attempts were made to devise an optimizer that follows 1.


The difficulty lies in the fact that we do not know the expected range of a parameter, but we could perhaps infer it from its current value. If a parameter we initialized at 0.001 wants to grow to 10, it can begin to accelerate more and more as it grows.


However, two problems here:

  1. An issue with using the current parameter value as a multiplier is that our updates shrink toward nothing as the parameter approaches zero, risking parameters getting "stuck" and making it impossible for them to cross over from positive to negative and vice versa.

  2. Secondly, you risk a weight growing without bound and destabilizing training by exploding.
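Both problems show up in a toy scalar example. This sketch assumes a purely multiplicative rule with a constant gradient sign; the helper `relative_step`, the learning rate and the starting values are all hypothetical.

```python
def relative_step(param, grad_sign, lr):
    """Purely multiplicative update: step size is proportional to |param|."""
    return param - lr * grad_sign * abs(param)

lr = 0.1

# Problem 1: the gradient keeps saying "decrease", but the parameter never crosses zero.
p = 0.01
for _ in range(1000):
    p = relative_step(p, grad_sign=+1.0, lr=lr)
print("after 1000 decreasing steps:", p)   # still positive, shrinking by 0.9x per step

# Problem 2: the gradient keeps saying "increase", and growth compounds.
q = 0.01
for _ in range(200):
    q = relative_step(q, grad_sign=-1.0, lr=lr)
print("after 200 increasing steps:", q)    # 0.01 * 1.1**200, already very large
```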


To address 2, I was able to find stability by adding heavy weight decay and capping the maximum step size.


There were two versions I tried.

  1. Half-n-Half

    1. At smaller values, the parameter scaled learning rate goes to 0, and the basic learning rate becomes dominant. At large values, the parameter scaled learning rate becomes dominant.

base_update = update * lr
param_scaled_update = update * param.abs() * param_lr
new_update = base_update * 0.5 + param_scaled_update * 0.5
  2. Epsilon

    1. Adds a small value to our parameter scaling to avoid 0.

new_update = update * (param.abs() + param_eps) * param_lr

One could also use the EMA copy of the model's parameters for the scaling if one is handy.


These are simple drop-in replacements that can be used with any optimizer.
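The two variants can be written as small functions that wrap whatever direction the base optimizer produces. This is a sketch assuming NumPy; `apply_step`, with its particular clipping and decoupled weight decay, is a hypothetical way to add the step cap and weight decay mentioned earlier, and all default values are illustrative.

```python
import numpy as np

def half_n_half(update, param, lr, param_lr):
    """Average a plain step with a step scaled by the parameter's magnitude."""
    base_update = update * lr
    param_scaled_update = update * np.abs(param) * param_lr
    return 0.5 * base_update + 0.5 * param_scaled_update

def epsilon_scaled(update, param, param_lr, param_eps=1e-3):
    """Scale the step by |param| + eps so it never collapses to zero."""
    return update * (np.abs(param) + param_eps) * param_lr

def apply_step(param, step, lr, weight_decay=0.1, max_step=0.05):
    """Hypothetical stabilizers: clip each element's step, then apply decoupled weight decay."""
    step = np.clip(step, -max_step, max_step)
    return param - step - lr * weight_decay * param

param = np.array([0.0, 0.004, 10.0])
update = np.array([1.0, 1.0, 1.0])     # e.g. a sign-like direction from ADAM

print(half_n_half(update, param, lr=1e-3, param_lr=1e-2))
print(epsilon_scaled(update, param, param_lr=1e-2))
print(apply_step(param, epsilon_scaled(update, param, param_lr=1e-2), lr=1e-3))
```

Note that both variants still produce a nonzero step for the parameter sitting exactly at 0, which is the point of mixing in the base learning rate or the epsilon.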


After some tuning of both the baseline and Relative ADAM, I found that Relative ADAM can give a pleasant boost, running about 15k steps ahead of schedule on GPT2. However, the gap appears to close, and the length I trained for is hardly telling.



Why does Muon Work?

Muon, to me, is an oddball. A full explanation of how it works can be found here, though the short of it is that updates for 2-dimensional parameters are orthogonalized with respect to the 2d matrix. It approximates taking the singular value decomposition of the gradient momentum, rescales all the singular values to 1, and then scales the whole matrix based on the matrix's dimensions.

U, S, V = SVD(Momentum)
update = UV^T

In practice, SVD is very expensive, so we use iterative Newton-Schulz steps to orthogonalize the matrix. You will also see NS described as iterating to find the "nearest orthogonal matrix".
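Both routes can be compared directly. The sketch below assumes NumPy and uses the classic cubic Newton-Schulz iteration for simplicity; Muon implementations use a tuned polynomial with only a handful of steps, so treat this as an illustration of the idea rather than Muon's exact procedure. The test matrix and its singular values are made up.

```python
import numpy as np

def orthogonalize_svd(m):
    """Exact version: set every singular value of m to 1."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

def orthogonalize_newton_schulz(m, steps=30):
    """Classic cubic Newton-Schulz iteration X <- 1.5 X - 0.5 X X^T X.

    Dividing by the Frobenius norm first puts all singular values in (0, 1],
    where the iteration pushes each of them toward 1.
    """
    x = m / (np.linalg.norm(m) + 1e-12)
    for _ in range(steps):
        x = 1.5 * x - 0.5 * (x @ x.T @ x)
    return x

# Stand-in for a momentum matrix with known singular values 3, 1, 0.5 and 0.1.
rng = np.random.default_rng(0)
u, _ = np.linalg.qr(rng.standard_normal((4, 4)))
v, _ = np.linalg.qr(rng.standard_normal((6, 4)))
momentum = u @ np.diag([3.0, 1.0, 0.5, 0.1]) @ v.T

exact = orthogonalize_svd(momentum)
approx = orthogonalize_newton_schulz(momentum)
print(np.abs(exact - approx).max())               # small
print(np.linalg.svd(approx, compute_uv=False))    # all close to 1
```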


There are a few reasons Muon comes out of left field compared to the familiar optimizer meta:


ADAM is like BatchNorm for gradients, Muon is...?

  • Adam can be seen as an analogue to batch normalization, but performed on gradients.

  • Batch normalization takes running average estimates of the mean and variance of the activations at a given layer. We then subtract the mean and divide by the standard deviation to transform an incoming set of activations to have approximately mean=0 and var=1, whitening activations without considering correlation components. Mean and element-wise variance are each only [dim], making for a fairly cheap operation. To the best of my knowledge, the same momentum is used for both mean and variance estimates.

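beta = 0.9
eps = 1e-5
# running estimates of the per-feature mean and variance, taken over the batch dimension
act_mean = act_mean * beta + acts.mean(dim=0) * (1 - beta)
act_var = act_var * beta + acts.var(dim=0) * (1 - beta)
whitened_acts = (acts - act_mean) / (act_var + eps).sqrt()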
  • For ADAM, the running variance of the gradients estimates the diagonal of the Fisher Information matrix, which is itself an approximation of the Hessian matrix, a measurement of the curvature of the loss landscape. One key difference is that instead of subtracting the mean/momentum from the gradient, we use the mean itself as the update.

beta_mean = 0.9
beta_var = 0.99
eps = 1e-8
grad_mean = grad_mean * beta_mean + grads * (1 - beta_mean)
# uncentered second moment: the mean is not subtracted before squaring
grad_var = grad_var * beta_var + grads**2 * (1 - beta_var)
whitened_grads = grad_mean / (grad_var.sqrt() + eps)
  • Shampoo and SOAP are the next level up to ADAM. Similarly, we track running statistics of the gradient like the mean. Though instead of just the diagonal variance, all covariance components are tracked. For a [N, N] gradient (imagine 512x512 parameter), the size of the covariance would be [N^2, N^2], which can be prohibitively large, not to mention this is only considering covariance for a given parameter matrix, and not covariance across the entirety of the model, which is what we'd really like to have. Because of this, we take a Kronecker approximation of the FIM, which results in 2x [N, N] matrices instead of [N^2, N^2].

  • The closest analogy to something that acts on the activations on the forward pass is Cloneofsimo's decorrelation layer. Because the feature dimension is much smaller than the gradient matrix, a [dim, dim] sized covariance tracker is affordable, though the whitening calculation itself, requiring torch.eig, is a bit expensive, so it's performed in chunks. To better understand the relation to BatchNorm, this could also be named BatchNormWithDecorrelation.

  • Muon doesn't exactly have a forward-pass analogue. BatchNorm, LayerNorm, GroupNorm etc. all act with respect to the 1d feature dimension and perform some kind of whitening based on statistics of the activations.

  • Instead Muon orthogonalizes gradients with respect to their 2d structure. This causes the singular vectors to all be orthogonal to one another and to be of equal magnitude (singular value).
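For a sense of scale, the 512x512 case from the Shampoo/SOAP bullet works out as follows. This is plain arithmetic on entry counts and ignores dtype and any other optimizer state.

```python
n = 512                              # the 512x512 parameter used as the example in the text
full_cov_entries = (n * n) ** 2      # covariance over all n^2 elements: [n^2, n^2]
kronecker_entries = 2 * n * n        # two [n, n] Kronecker factors

print(f"full: {full_cov_entries:,}")          # 68,719,476,736
print(f"kronecker: {kronecker_entries:,}")    # 524,288
print(full_cov_entries // kronecker_entries)  # 131072 times fewer entries
```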

Orthogonalization != Whitening / Preconditioning

  • The orthogonalization performed by Muon does not take into account the covariance between individual dimensions.

    • To the best of my knowledge, this inherently requires estimating statistics over multiple gradients. We cannot determine covariance between dimensions within a single sample instantaneously.

  • Instead of per-dimension, it considers the rows of the matrix, and ensures they are all orthogonal to one another, guaranteeing the update is approximately full rank.

  • Orthogonalization and whitening are often seen as similar. They both have the flavor of making elements of something orthogonal, but here orthogonalization refers to whole rows, whereas whitening refers to ensuring that individual dimensions have no correlation with one another.

  • If preconditioning aims to normalize the curvature of the space, as we saw in the 2d bowl example, I have no idea whether what Muon is doing can be considered preconditioning, but I really think its benefits are the result of something else entirely.

Stateless and "Instantaneous"

  • Like most optimizers these days, Muon uses momentum.

  • Though to achieve preconditioning, other optimizers keep a running state tracking the second moment/variance of the gradients as a way of empirically estimating curvature.

    • (Though to be honest, even this feels like a hardly reliable way to estimate curvature in a turbulent, changing loss landscape, calling into question what other optimizers are really doing as well. To me, the only curvature estimate that feels worth trusting and doesn't rely on assumptions about the landscape is directly calculating the Hessian.)

  • Muon does not do this; its orthogonalization does not require an accumulated state describing the landscape. Its orthogonalization is "instantaneous", meaning it does not require an accumulating estimate.

Constant Update Size

  • Regardless of the magnitude of the incoming gradient (momentum), the update produced after performing Newton-Schulz will always have the same norm. This is a direct byproduct of making the matrix orthonormal, along with the scaling we apply to the matrix based on its size.

  • So, whether we are in steeper or flatter parts of the loss landscape, we take the same size step regardless.

  • At the very least, this is troubling near convergence. Assuming your gradient norms are dropping towards zero because we are reaching the basin/solution of our optimization problem, the update still ends up being rescaled to the same step size. So even if we pull a gradient of 1e-16, because we are simply at our solution and have no more descent to do, this will be rescaled to be a larger step, overshooting our objective.

  • In practice, learning rate scheduling can help control for this. Not to mention, in legit deep learning problems, zero-gradient absolute convergence is rarely seen anyway. It is more common to "orbit" around a solution, a "stochastic" kind of convergence where we stay in the same area consistently, with noise but no drift (citation needed, I forgot the paper).

  • This was something I attempted to address here. We can either 1. rescale the update to match its norm with the original momentum's norm or 2. Keep a running average of gradient norms to match to for some added flexibility while still maintaining some normalization.
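The constant step size, and the first of the two fixes, can be seen in a few lines. This sketch assumes NumPy and uses an exact SVD in place of Newton-Schulz; the momentum matrix is random and the size-based scaling factor is left out.

```python
import numpy as np

def orthogonalize(m):
    """Set every singular value of m to 1 (exact stand-in for Newton-Schulz)."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt

rng = np.random.default_rng(0)
momentum = rng.standard_normal((4, 6))

for scale in (1.0, 1e-8):
    update = orthogonalize(scale * momentum)
    # Frobenius norm is sqrt(min(4, 6)) = 2 no matter how small the input is.
    print(scale, np.linalg.norm(update))

# Option 1 from the text: give the update the norm of the incoming momentum.
tiny = 1e-8 * momentum
update = orthogonalize(tiny)
rescaled = update * np.linalg.norm(tiny) / np.linalg.norm(update)
print(np.linalg.norm(rescaled), np.linalg.norm(tiny))
```

Option 2 would replace `np.linalg.norm(tiny)` with a running average of recent momentum norms.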

Sign Flips (is this a problem?)

  • An interesting behavior of Muon's orthogonalization is that the signs of dimensions commonly flip, while in ADAM-like optimizers, this never happens.

  • This is a bit strange to me: the gradient of our model has told us "move this way to decrease the loss" and we move in the opposite direction? Rescaling is one thing, but flipping the direction is something I don't know how to explain yet.

  • Note: Some friends and I have found the same phenomenon occurring with SOAP/PSGD/Shampoo as well, which are all performant optimizers. So it's hard to criticize this as a bad thing, but I wish I better understood why this is okay.

    • Part of me wonders if anywhere a sign flip occurs we should zero out those dimensions instead.
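A sign flip is easy to reproduce on a small example. The matrix below is hypothetical and chosen so that every entry is positive; the sketch assumes NumPy and an exact SVD orthogonalization, and the masking at the end is only one possible reading of the zeroing idea floated above.

```python
import numpy as np

momentum = np.array([[2.0, 1.0],
                     [1.0, 0.1]])     # hypothetical values, every entry positive

u, _, vt = np.linalg.svd(momentum)
update = u @ vt
flipped = np.sign(update) != np.sign(momentum)

print(update)    # approximately [[0.689, 0.725], [0.725, -0.689]]
print(flipped)   # only the bottom-right entry changes sign

# Zero out the flipped entries instead of moving against the gradient.
masked = np.where(flipped, 0.0, update)
print(masked)
```

An element-wise normalization such as `momentum / np.abs(momentum)` can never change a sign, which is why this behaviour does not appear in ADAM-like optimizers.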



So with all this in mind, what is Muon actually doing then?


Two theories come to mind, though there are probably many other viable explanations:


Muon amplifies gradient noise

  • Singular vectors corresponding to smaller, noisier directions of the gradient that haven't been entirely smoothed out of existence by momentum now have their singular values raised to one, like all other directions.

  • It is a fact that some parts of our gradient are simply just noise, and the rescaling we do treats all directions equally regardless of what is signal and what is noise.

    • In general, I would imagine we would want to average out these to 0 where possible.

  • However, it's possible this is somehow useful.

  • I don't really know how, but I could imagine this being related to escaping saddle points or poor minima more easily.
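A toy illustration of the amplification, assuming NumPy and an exact SVD orthogonalization. The "signal" and "noise" directions and their sizes are hypothetical; a real gradient would not separate this cleanly.

```python
import numpy as np

# Hypothetical momentum: one strong "signal" direction and one tiny "noise" direction.
signal = 1.0 * np.outer([1.0, 0.0], [1.0, 0.0])
noise = 1e-3 * np.outer([0.0, 1.0], [0.0, 1.0])
momentum = signal + noise

u, s, vt = np.linalg.svd(momentum)
update = u @ vt

def noise_share(m):
    """Fraction of the squared Frobenius norm lying along the noise direction."""
    return m[1, 1] ** 2 / np.sum(m ** 2)

print("singular values before:", s)                   # 1.0 and 0.001
print("noise share before:", noise_share(momentum))   # about 1e-6
print("noise share after:", noise_share(update))      # 0.5: noise now counts as much as signal
```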


Muon better respects relativity

  • This one is more related to the topic of this post.

  • Unlike normalization that occurs element-wise, Muon's orthonormalization allows for dimensions with larger gradients to stay large and smaller ones to stay small.

  • Let's see an example of this:


Let's imagine this is our gradient/momentum. Note that row 2 and row 3 are very much correlated to each other, we're trying a "worst case" here to see how things hold up.

Let's also imagine these are pretty typical values for these dimensions, so ADAM's second moment would bring them all to similar values post-normalization.


If we instead apply Newton-Schulz, we get out this matrix.

It's not entirely perfect, and part of this is also because of low-rankness. But larger dimensions do a decent job of staying relatively large compared to other elements in the same row.


If the gradient is a bit closer to orthogonal to begin with:

The relativity holds a bit better.


Note the sign flips also.


  • By this, we've almost come full circle back to Signum (momentum SGD) but with a few extra tricks to normalize our updates,

    • and of course orthogonalize them, which is not entirely explained by this point.

  • This also makes a case for why methods with factorized/compressed second moments like AdaFactor can still match the performance of ADAM quite well.

    • Since relative sizes of each dimension may be better preserved rather than forcing each dimension to the same range.
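The contrast between element-wise normalization and orthogonalization can be checked on a small case. The matrix here is hypothetical, not the one from the example above, and is chosen to be an exactly scaled rotation so its rows are already orthogonal; the sketch assumes NumPy and an exact SVD orthogonalization.

```python
import numpy as np

# Hypothetical gradient that is already a scaled rotation, so its rows are orthogonal.
grad = np.array([[3.0, 0.1],
                 [-0.1, 3.0]])

elementwise = np.sign(grad)                 # what a sign-like normalization produces
u, _, vt = np.linalg.svd(grad)
orthogonalized = u @ vt

print(np.abs(elementwise[0, 0] / elementwise[0, 1]))        # 1.0: relative sizes are erased
print(np.abs(orthogonalized[0, 0] / orthogonalized[0, 1]))  # about 30, same ratio as grad
```

In this best case the orthogonalized update is just the gradient divided by a single scalar, so every ratio between elements survives.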










