I came across this problem on LeetCode when it appeared in one of the daily challenges.

The problem asks you to write a function that takes an integer n and returns an n by n matrix filled with the numbers 1 to n^2 in spiral order.

So, for example, if n = 4, the output should be:

1 2 3 4
12 13 14 5
11 16 15 6
10 9 8 7

The problem is very simple. A typical solution traverses the matrix layer by layer, from the outside in, in a spiral fashion, keeping track of the current number’s index. However, this typical solution has an inherent problem: poor cache locality.
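For reference, here is a minimal sketch of that typical layer-by-layer fill. The class name SpiralLayers and the boundary-variable approach (top, bottom, left, right) are my own choices for illustration, not LeetCode’s template. Notice that the right-column and left-column loops change the row on every step.

```java
import java.util.Arrays;

// Sketch of the conventional layer-by-layer spiral fill (illustrative names).
public class SpiralLayers {
    public static int[][] generateMatrix(int n) {
        int[][] matrix = new int[n][n];
        int value = 1;
        int top = 0, bottom = n - 1, left = 0, right = n - 1;
        while (top <= bottom && left <= right) {
            // Top row, left to right: stays inside one row.
            for (int c = left; c <= right; c++) matrix[top][c] = value++;
            top++;
            // Right column, top to bottom: every step lands in a different row.
            for (int r = top; r <= bottom; r++) matrix[r][right] = value++;
            right--;
            // Bottom row, right to left (skipped if the layer was a single row).
            if (top <= bottom) {
                for (int c = right; c >= left; c--) matrix[bottom][c] = value++;
                bottom--;
            }
            // Left column, bottom to top: again one new row per step.
            if (left <= right) {
                for (int r = bottom; r >= top; r--) matrix[r][left] = value++;
                left++;
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        // Prints the n = 4 example from above, one row per line.
        for (int[] row : generateMatrix(4)) System.out.println(Arrays.toString(row));
    }
}
```

Running main prints [1, 2, 3, 4], [12, 13, 14, 5], [11, 16, 15, 6] and [10, 9, 8, 7], matching the example.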

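If you look at this spiral access pattern, you will find that it jumps around in memory. This is because memory is one-dimensional, so a matrix has to be laid out either row-wise or column-wise. Row-wise (row-major) means each row of the matrix is stored right after the previous one; column-wise (column-major) means the same for columns. Stepping along the other dimension therefore skips a whole row (or column) of elements at a time. In Java specifically, an int[][] is an array of separate row arrays, so each row is contiguous, but consecutive rows are not guaranteed to sit next to each other.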
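To make the jumping concrete, the sketch below lists the memory offsets a spiral traversal touches, assuming the matrix lives in a single contiguous row-major array where cell (row, col) sits at offset row * n + col. That flat layout and the helper name spiralOffsets are assumptions for illustration only.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: the row-major offsets (row * n + col) visited, in order,
// by a spiral traversal of an n x n matrix stored in one contiguous array.
// (For a column-major layout the offset would be col * n + row instead.)
public class SpiralOffsets {
    static List<Integer> spiralOffsets(int n) {
        List<Integer> offsets = new ArrayList<>();
        int top = 0, bottom = n - 1, left = 0, right = n - 1;
        while (top <= bottom && left <= right) {
            for (int c = left; c <= right; c++) offsets.add(top * n + c);
            top++;
            for (int r = top; r <= bottom; r++) offsets.add(r * n + right);
            right--;
            if (top <= bottom) {
                for (int c = right; c >= left; c--) offsets.add(bottom * n + c);
                bottom--;
            }
            if (left <= right) {
                for (int r = bottom; r >= top; r--) offsets.add(r * n + left);
                left++;
            }
        }
        return offsets;
    }

    public static void main(String[] args) {
        // For n = 4 this prints [0, 1, 2, 3, 7, 11, 15, 14, 13, 12, 8, 4, 5, 6, 10, 9]
        System.out.println(spiralOffsets(4));
    }
}
```

The runs of neighbouring offsets (0 to 3, 15 down to 12) come from the row segments; the strides of n (3, 7, 11, 15 and 12, 8, 4) are the column segments.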

A CPU does not fetch memory one value at a time: it loads the blocks of memory your application is working on into its cache, in fixed-size chunks called cache lines (commonly 64 bytes). If you really want to optimize the performance of your program, one thing you can do is reduce cache misses. A cache miss happens when the CPU looks up a memory location, doesn’t find it in the cache, and has to fetch it from further down the memory hierarchy, ultimately from main memory (RAM), which is much slower.
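As a small illustration of access order (not a benchmark), here are two loops that read every cell of an int[][] exactly once and return the same sum. The row-wise loop walks each row array sequentially, which is the pattern caches and hardware prefetchers handle best; the column-wise loop reads from a different row array on every inner step. The class and method names are made up for this sketch, and the column-wise version assumes a rectangular matrix.

```java
// Illustrative only: same result, different memory access order.
public class TraversalOrder {
    // Row-wise: the inner loop moves through one row array sequentially,
    // so consecutive reads usually fall in the same or the next cache line.
    static long sumRowWise(int[][] m) {
        long sum = 0;
        for (int i = 0; i < m.length; i++)
            for (int j = 0; j < m[i].length; j++)
                sum += m[i][j];
        return sum;
    }

    // Column-wise (assumes all rows have the same length): every inner step
    // switches to another row array, so a large matrix causes many more misses.
    static long sumColumnWise(int[][] m) {
        long sum = 0;
        int cols = m.length == 0 ? 0 : m[0].length;
        for (int j = 0; j < cols; j++)
            for (int i = 0; i < m.length; i++)
                sum += m[i][j];
        return sum;
    }

    public static void main(String[] args) {
        int n = 1000;
        int[][] m = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                m[i][j] = i + j;
        System.out.println(sumRowWise(m) == sumColumnWise(m)); // prints true
    }
}
```

If you want to measure the difference yourself, a harness such as JMH is more trustworthy than timing these loops by hand, because JIT warm-up and optimizations can easily distort naive measurements.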

Now assume, for the sake of discussion, that the cache can hold only one row of the matrix at a time (a simplification, but it captures what happens when the matrix is much larger than the cache) and that the matrix is stored row-wise. Then every access to a row other than the one currently in the cache is a cache miss. If we traverse the matrix layer by layer from the outside in, every step along a layer’s leftmost or rightmost column moves to a different row, so we get a cache miss on almost every cell of those columns (the row segments, by contrast, stay within the cached row).
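The toy model in the previous paragraph is easy to simulate. The sketch below assumes exactly that model (a cache holding a single row, starting empty) and counts a miss whenever the row being accessed differs from the cached one. It is a model of the argument, not of real hardware, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Toy model (an assumption, not real hardware): the cache holds exactly one
// matrix row, the matrix is stored row-wise, and touching any other row is a miss.
public class OneRowCacheModel {
    // Rows touched, in order, by the usual layer-by-layer spiral.
    static List<Integer> spiralRowSequence(int n) {
        List<Integer> rows = new ArrayList<>();
        int top = 0, bottom = n - 1, left = 0, right = n - 1;
        while (top <= bottom && left <= right) {
            for (int c = left; c <= right; c++) rows.add(top);      // top row
            top++;
            for (int r = top; r <= bottom; r++) rows.add(r);        // right column
            right--;
            if (top <= bottom) {
                for (int c = right; c >= left; c--) rows.add(bottom); // bottom row
                bottom--;
            }
            if (left <= right) {
                for (int r = bottom; r >= top; r--) rows.add(r);    // left column
                left++;
            }
        }
        return rows;
    }

    // Rows touched by a plain row-wise traversal.
    static List<Integer> rowWiseRowSequence(int n) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                rows.add(i);
        return rows;
    }

    // The cache starts empty (-1) and always holds the last row that was used.
    static int countMisses(List<Integer> rowSequence) {
        int cachedRow = -1, misses = 0;
        for (int row : rowSequence) {
            if (row != cachedRow) {
                misses++;
                cachedRow = row;
            }
        }
        return misses;
    }

    public static void main(String[] args) {
        int n = 4;
        System.out.println("spiral:   " + countMisses(spiralRowSequence(n)));  // 7
        System.out.println("row-wise: " + countMisses(rowWiseRowSequence(n))); // 4
    }
}
```

Under this model the row-wise order pays one miss per row, while the spiral pays about one miss per cell of every left and right column segment, so the gap grows with n.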

To minimize cache misses, the fix sounds simple: traverse the matrix row-wise. But that is easier said than done. I spent about an hour looking for a formula that, given a row index and a column index, returns the index that cell would have if the matrix were traversed in spiral order. That code is below.

I think you definitely shouldn’t try to come up with such a solution in an interview. However, noticing and mentioning the cache locality issue with the spiral access pattern, and suggesting a solution for it, shows that you know something about computer architecture and performance optimization, and I believe it will impress your interviewer :D

In this problem n <= 20, so the matrix has at most 400 cells (1,600 bytes of int data), which easily fits in the cache all at once. My code is therefore probably slower than the typical solution, since it does more arithmetic per cell.

One other benefit of this row-wise traversal is that it can be parallelized across threads, or even machines, depending on where the matrix is stored, since each cell’s value depends only on its own row and column.
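Here is a minimal sketch of the thread-level version, using a parallel IntStream so that each task fills one whole row. The static spiralIndex method is my own condensed restatement of the post’s formula (layerIdx, indexInLayer and cellsInOuterMostLayers from the code at the end of this post, inlined into one method). Spreading the work across machines is not shown, and for n <= 20 the parallel overhead will almost certainly outweigh any gain.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Sketch: rows are filled independently, so they can be split across threads.
public class ParallelSpiral {
    // Condensed restatement of the closed-form spiral index (0-based).
    static int spiralIndex(int n, int row, int col) {
        int layer = Math.min(Math.min(row, col), Math.min(n - 1 - row, n - 1 - col));
        int edge = n - 1 - 2 * layer;          // length of each of the layer's four segments
        int before = 4 * layer * (n - layer);  // cells in all layers outside this one
        int offset;
        if (row == layer && col < n - 1 - layer) offset = col - layer;            // top row
        else if (col == n - 1 - layer) offset = edge + (row - layer);             // right column
        else if (row == n - 1 - layer) offset = 2 * edge + (n - 1 - layer - col); // bottom row
        else offset = 3 * edge + (n - 1 - layer - row);                           // left column
        return before + offset;
    }

    static int[][] generateMatrixParallel(int n) {
        int[][] matrix = new int[n][n];
        // Each task writes only to its own row array, so no locking is needed.
        IntStream.range(0, n).parallel().forEach(i -> {
            for (int j = 0; j < n; j++)
                matrix[i][j] = spiralIndex(n, i, j) + 1;
        });
        return matrix;
    }

    public static void main(String[] args) {
        for (int[] row : generateMatrixParallel(4))
            System.out.println(Arrays.toString(row));
    }
}
```

The stream’s terminal forEach only returns after all row tasks finish, so the caller sees the fully filled matrix.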

Here is the code:

class Solution {
    int n;

    public int[][] generateMatrix(int n) {
        this.n = n;
        int[][] matrix = new int[n][n];
        // Fill row by row: each cell's value depends only on (row, col).
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                matrix[i][j] = indexInMatrix(i, j) + 1;
        return matrix;
    }

    // Returns the 0-based index of the cell (row, col) if the matrix were traversed in spiral order.
    private int indexInMatrix(int row, int col) {
        int layer = layerIdx(row, col);
        int indexInLayer = indexInLayer(row, col);
        return cellsInOuterMostLayers(layer) + indexInLayer;
    }

    // Returns the 0-based index of the cell (row, col) within its own layer, as if the layer
    // were unwound into a single row (walking the layer clockwise from its top-left corner).
    private int indexInLayer(int row, int col) {
        int layer = layerIdx(row, col);
        // Length of each of the layer's four segments (cells per side minus one).
        int edgeSize = n - 1 - layer * 2;
        // top row
        if (row == layer && col < n - 1 - layer)
            return col - layer;
        // right col
        if (col == n - 1 - layer)
            return edgeSize * 1 + (row - layer);
        // bottom row
        if (row == n - 1 - layer)
            return edgeSize * 2 + (n - 1 - layer - col);
        // left col
        return edgeSize * 3 + (n - 1 - layer - row);
    }

    // Returns the layer index of the cell (row, col), where 0 is the outermost layer.
    private int layerIdx(int row, int col) {
        return Math.min(
            Math.min(row, col),
            Math.min(n - row - 1, n - col - 1)
        );
    }

    // Returns the total number of cells in the outermost layers up to, but not including,
    // the given layer index (0-indexed).
    // F(0) = 0
    // F(1) = 4 * (n - 1)
    // F(2) = 4 * (n - 1 + n - 3) = 4 * (2n - (1 + 3))
    // F(3) = 4 * (n - 1 + n - 3 + n - 5) = 4 * (3n - (1 + 3 + 5))
    // F(x) = 4 * (x * n - x * x)        (x * x = the sum of the first x odd numbers :D)
    // F(x) = 4 * x * (n - x)
    private int cellsInOuterMostLayers(int layer) {
        return 4 * layer * (n - layer);
    }
}
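The derivation in the comments of cellsInOuterMostLayers can be sanity-checked numerically by summing the ring sizes directly: layer k has sides of n - 2k cells, so its ring holds 4 * (n - 1 - 2k) cells. The class below is a hypothetical checking helper, not part of the solution; it only covers the layer indices the solution can actually pass in, 0 through (n - 1) / 2.

```java
// Hypothetical check that F(x) = 4 * x * (n - x) equals the sum of the
// ring sizes 4 * (n - 1 - 2k) for k = 0 .. x - 1.
public class LayerCountCheck {
    // Adds up the rings of layers 0..x-1 one by one.
    static int bruteForce(int n, int x) {
        int total = 0;
        for (int k = 0; k < x; k++)
            total += 4 * (n - 1 - 2 * k);
        return total;
    }

    // The closed form used by cellsInOuterMostLayers.
    static int closedForm(int n, int x) {
        return 4 * x * (n - x);
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 20; n++) {
            // x ranges over the layer indices the solution can pass in.
            for (int x = 0; x <= (n - 1) / 2; x++) {
                if (bruteForce(n, x) != closedForm(n, x))
                    throw new AssertionError("mismatch at n=" + n + ", x=" + x);
            }
        }
        System.out.println("closed form matches for n = 1..20");
    }
}
```

For example, with n = 5 and x = 2 both methods give 16 + 8 = 24.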

If you would like to leave a comment, please visit my post for this problem on LeetCode here.