Phase transitions in Stochastic Gradient Descent, from a High-dimensional Data Visualisation Perspective
R
high-dimensional data
deep learning
data visualisation
statistical graphics
dimension reduction
tours
data mining
Author
Di Cook
Published
11 March 2025
Motivation
Susan Wei kicked off NUMBATs for 2025 last Thursday, with a talk about a recently completed project documented in this arXiv paper. The work is motivated by the Toy Models of Superposition paper, which provides an explanation of hidden layer behaviour in neural networks in the presence of differences induced by sample size, dimension and correlation. Susan’s work relates this to optimisation curiosities, and particularly the connection between learning via SGD and Bayesian learning. The example used to illustrate the optimisation quirk is also interesting from the perspective of high-dimensional visualisation using tours.
About the data example
The 6D data used in Susan’s paper is generated by random uniform sampling along each of the six axes. The code below generates this sort of sample.
Code
library(tourr)
library(ggplot2)
library(tibble)
library(dplyr)

generate_toy_dataset <- function(num_samples, n, noise_std, seed) {
  set.seed(seed)
  # Generate random integers for 'a' (equivalent to JAX's randint)
  a <- sample(0:(n-1), num_samples, replace = TRUE)
  # Generate uniform random values for 'lambda_val'
  lambda_val <- runif(num_samples)
  # Create input matrix x
  x <- matrix(0, nrow = num_samples, ncol = n)
  for (i in 1:num_samples) {
    x[i, a[i] + 1] <- lambda_val[i]  # +1 since R uses 1-based indexing
  }
  # Generate Gaussian noise
  gaussian_noise <- matrix(rnorm(num_samples * n, mean = 0, sd = noise_std),
                           nrow = num_samples, ncol = n)
  # Compute y
  y <- x + gaussian_noise
  return(list(x = x, y = y))
}
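As a usage sketch (the noise_std and seed values here are placeholders, not necessarily those used for the post's figures), the three samples toured below can be generated and turned into data frames like this:

Code

d50 <- generate_toy_dataset(num_samples = 50, n = 6, noise_std = 0.05, seed = 2025)
d100 <- generate_toy_dataset(num_samples = 100, n = 6, noise_std = 0.05, seed = 2025)
d500 <- generate_toy_dataset(num_samples = 500, n = 6, noise_std = 0.05, seed = 2025)
# Data frames of the inputs (columns V1-V6), used for the tours below
dx50 <- as.data.frame(d50$x)
dx100 <- as.data.frame(d100$x)
dx500 <- as.data.frame(d500$x)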
This can be viewed using a grand tour, made with the tourr package, using the code below:
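A minimal sketch of that tour call, assuming the dx50 data frame from above; render_gif() and the gif file name are placeholders for saving the animation, not the post's exact settings:

Code

# Interactive grand tour in the graphics device
animate_xy(dx50, tour_path = grand_tour(), axes = "bottomleft", col = "orange")
# Or save the animation to a gif file (file name is a placeholder)
render_gif(dx50, grand_tour(), display_xy(axes = "bottomleft", col = "orange"),
           gif_file = "d50.gif")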
And here are tours of the data for different sample sizes, n=50, 100, 500, shown as orange points. The axes producing each 2D projection in the stream of projections are represented by the line segments and circle. You can see that the data only lies along the axes: if one variable has a non-zero value, all other variables are zero. It’s a contrived example, but it does arise in the analysis of large text data, where each word is unique, so when one word is used the other words in the set are not. This is the data that generates the interesting phase shifts during optimisation.
n=50
n=100
n=500
\(k\)-gons
Here we connect the most extreme point on each axis, and tour again. Because there are 6 axes, and we connect each extreme to every other extreme, there are 15 edges in total. By focusing on only the convex hull of these edges for each projection, you can see that different polygon (\(k\)-gon) shapes appear in the 2D projections. The number of visible vertices ranges from 3 to 6, so we have 3-, 4-, 5- and 6-gons. A 6-gon is formed when all six axes extend radially from the centre. It would be an ideal 6-gon (hexagon) if the angles between the axes were equal and the line segments were of equal length.
Code
load("d50.rda")
load("d100.rda")
load("d500.rda")

mx50 <- c(which.max(dx50$V1), which.max(dx50$V2), which.max(dx50$V3),
          which.max(dx50$V4), which.max(dx50$V5), which.max(dx50$V6))
edges50 <- NULL
for (i in 1:5)
  for (j in (i+1):6)
    edges50 <- rbind(edges50, c(mx50[i], mx50[j]))
colnames(edges50) <- c("from", "to")

mx100 <- c(which.max(dx100$V1), which.max(dx100$V2), which.max(dx100$V3),
           which.max(dx100$V4), which.max(dx100$V5), which.max(dx100$V6))
edges100 <- NULL
for (i in 1:5)
  for (j in (i+1):6)
    edges100 <- rbind(edges100, c(mx100[i], mx100[j]))
colnames(edges100) <- c("from", "to")

mx500 <- c(which.max(dx500$V1), which.max(dx500$V2), which.max(dx500$V3),
           which.max(dx500$V4), which.max(dx500$V5), which.max(dx500$V6))
edges500 <- NULL
for (i in 1:5)
  for (j in (i+1):6)
    edges500 <- rbind(edges500, c(mx500[i], mx500[j]))
colnames(edges500) <- c("from", "to")
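One way these edges can be overlaid on a tour is sketched below (an assumption, not the original code): the edge matrix is passed to animate_xy(), which hands it on to display_xy(). Note that this draws all 15 edges rather than only the convex hull in each projection.

Code

animate_xy(dx50, tour_path = grand_tour(), edges = edges50,
           axes = "bottomleft", col = "orange")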
In Susan’s example, the optimisation is training a neural network designed to take 6D data, transform it into 2D, and then recover the 6D data again. The model is learning the coefficients for the 2D projection, that is, the axes as shown in the representations from a tour. The target output is the same as the input data, with some additional noise. I have no idea why this is an interesting problem, except perhaps that it relates to finding a useful 2D representation of the data structure, one that would give the viewer a good chance of recognising the underlying structure of the high-dimensional data. This type of neural network training is the basis of the Toy Models of Superposition paper.
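To make that concrete, here is a rough sketch in plain R of this kind of training, following the Toy Models of Superposition setup rather than Susan’s code; the tied-weight architecture, ReLU output, learning rate and number of steps are all assumptions. The 6D input x is squashed to 2D by W, expanded back to 6D by t(W), and W and b are updated by stochastic gradient descent on the squared error against the noisy target y.

Code

# Assumed sketch: yhat = relu(t(W) %*% (W %*% x) + b), trained with plain SGD.
# W is the 2 x 6 projection whose columns play the role of the axes in the tours.
relu <- function(z) pmax(z, 0)

train_toy_model <- function(x, y, k = 2, lr = 0.01, steps = 20000, seed = 1) {
  set.seed(seed)
  n <- ncol(x)
  W <- matrix(rnorm(k * n, sd = 0.1), nrow = k)  # 2D projection being learned
  b <- rep(0, n)
  for (s in 1:steps) {
    i <- sample(nrow(x), 1)           # one observation per SGD step
    xi <- x[i, ]
    yi <- y[i, ]
    h <- as.vector(W %*% xi)          # 6D -> 2D
    z <- as.vector(t(W) %*% h) + b    # 2D -> back to 6D
    yhat <- relu(z)
    g <- 2 * (yhat - yi) * (z > 0)    # gradient of squared error through the ReLU
    W <- W - lr * (outer(h, g) + outer(as.vector(W %*% g), xi))  # dL/dW = h g' + (W g) x'
    b <- b - lr * g
  }
  list(W = W, b = b)
}

Under these assumptions, train_toy_model(d500$x, d500$y)$W returns a 2 x 6 matrix whose six columns are the learned 2D positions of the axes, the quantity being tracked through the phases of the optimisation.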
This is the illustration of the optimisation, from Susan’s paper. The loss appears to decrease in almost discrete steps, with big drops between \(k\)-gons. The vertices of the polygons (in the convex hull, at least) are considered to be critical points of the 2D projection, reached at the different phases of the optimisation.
Random projections of the data
The optimisation is working through the space of all possible 2D projections of the data. Here is a sample of 8 random projections, a (tiny) sample of the space that the optimisation is searching.
These are perfect 3-, 4-, 5- and 6-gon shapes, and the transition to a smaller loss should happen when an extra axis gets drawn out from 0. As you have seen, there are many in-between shapes, generated by small changes in the projection basis, producing polygons with 3, 4, 5 or 6 vertices that are not regularly placed.
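A sketch of where such random projections can come from (an assumed reconstruction, not the post's code): draw random orthonormal 2D bases with tourr::basis_random() and project the n=500 sample onto each.

Code

library(purrr)
set.seed(310)
projs <- map_dfr(1:8, function(p) {
  basis <- basis_random(n = 6, d = 2)   # random orthonormal 6 -> 2 basis
  prj <- as.matrix(dx500) %*% basis     # project the data onto it
  tibble(P1 = prj[, 1], P2 = prj[, 2], proj = p)
})
ggplot(projs, aes(x = P1, y = P2)) +
  geom_point(colour = "orange", alpha = 0.5) +
  facet_wrap(~proj, ncol = 4) +
  coord_equal() +
  theme_void()

Each panel is one random projection, and most show an irregular 3- to 6-sided outline rather than a regular \(k\)-gon.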
Next
Stay tuned for the next blog post using high-dimensional visualisation that teases apart what the Toy Models of Superposition paper is trying to illustrate.