<aside> 💡 This is a short blog about my understanding of the internal behavior of transpose() and why contiguous() is needed after this operation.
</aside>
transpose() does not alter the memory buffer. Instead, it only changes how elements are read from that buffer, using the tensor's “stride” attribute.
The memory buffer is a flat, contiguous chunk of memory that holds all of the tensor's data.
Suppose tensor ‘a’ = [[1, 2, 3], [4, 5, 6]], with shape (2, 3).
It is a 2-D tensor, but the memory buffer stores it as [1, 2, 3, 4, 5, 6].
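To make the flattening concrete, here is a minimal pure-Python sketch that lays a nested list out row by row, which is the order a freshly created PyTorch tensor uses. It is a toy model, not PyTorch's storage code, and `flatten_row_major` is a hypothetical name.

```python
def flatten_row_major(nested):
    """Lay out a 2-D nested list as one flat buffer, row after row."""
    return [value for row in nested for value in row]

a = [[1, 2, 3],
     [4, 5, 6]]
buffer = flatten_row_major(a)
print(buffer)  # [1, 2, 3, 4, 5, 6]
```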
PyTorch needs a rule to map each element to a position in the buffer, and this rule is the “stride”.
Stride → how many elements to skip in the memory buffer to move one step along each axis.
With transpose(), PyTorch does not allocate new memory for the transposed tensor. Instead, it changes the rule for traversing the buffer (the strides).
So the physical memory stays the same and only the logical view differs. But how does that work? 👇
For the tensor ‘a’ above, the stride is (3, 1): moving one row down skips 3 elements, and moving one column right skips 1.
The buffer position comes from a formula: for stride (i, j), element a[k][l] is at buffer_index = k * i + l * j.
So, for a[1][1]: 1 * 3 + 1 * 1 = 4, go to buffer[4] → element 5.
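The indexing rule fits in a small pure-Python helper. This sketch reproduces only the arithmetic above, using the toy buffer and strides of ‘a’; `buffer_index` is a hypothetical name, not a PyTorch function.

```python
def buffer_index(strides, index):
    """Map a multi-dimensional index to a flat buffer offset: sum(index[d] * strides[d])."""
    return sum(i * s for i, s in zip(index, strides))

buffer = [1, 2, 3, 4, 5, 6]   # flat storage of a = [[1, 2, 3], [4, 5, 6]]
strides = (3, 1)              # row-major strides for shape (2, 3)

offset = buffer_index(strides, (1, 1))
print(offset, buffer[offset])                 # 4 5
print(buffer[buffer_index(strides, (1, 0))])  # 4
```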
That is how PyTorch travels the memory buffer. What transpose() actually does is swap the strides (along with the shape): transposing a tensor essentially means transposing its strides.
The memory buffer still stores the original (non-transposed) tensor. So, for the transposed ‘a’ 👇
[[1, 4],
 [2, 5],
 [3, 6]], shape (3, 2), stride (1, 3) → the “transposed stride”
Accessing a[1][1] of the transposed tensor from [1, 2, 3, 4, 5, 6] (same memory layout): 1 * 1 + 1 * 3 = 4, that is, buffer[4] = element 5.
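Here is a toy pure-Python model of that idea: transposing only swaps the shape and stride tuples, and every element of the transposed view is still read from the untouched buffer. The helper names are hypothetical. In real PyTorch you can inspect the same numbers with `a.stride()` and `a.t().stride()`.

```python
def buffer_index(strides, index):
    """Flat buffer offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))

def transpose_2d(shape, strides):
    """Transpose a 2-D view by swapping its shape and strides; no data is copied."""
    return (shape[1], shape[0]), (strides[1], strides[0])

buffer = [1, 2, 3, 4, 5, 6]          # storage of a, never modified
shape, strides = (2, 3), (3, 1)      # a
t_shape, t_strides = transpose_2d(shape, strides)
print(t_shape, t_strides)            # (3, 2) (1, 3)

# Read the transposed view element by element from the SAME buffer.
rows = [[buffer[buffer_index(t_strides, (r, c))] for c in range(t_shape[1])]
        for r in range(t_shape[0])]
print(rows)  # [[1, 4], [2, 5], [3, 6]]
```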
“For contiguous memory, the strides have this property: stride[i] = product(shape[i+1:]), so the last stride is always 1.”
For a contiguous 3-D tensor of shape [2, 4, 3], the strides are [12, 3, 1] (applying the formula above).
But for a non-contiguous tensor, the strides are no longer the product of the trailing dimensions!
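The formula can be turned into a small checker. This is a simplified sketch: PyTorch's own contiguity check has extra special cases (for example, it ignores the stride of dimensions of size 1), which this version does not model. `contiguous_strides` and `is_contiguous` here are hypothetical helpers, not PyTorch functions.

```python
from math import prod

def contiguous_strides(shape):
    """Row-major strides: stride[i] = product(shape[i+1:]) (empty product = 1)."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))

def is_contiguous(shape, strides):
    """Toy check: the strides equal the row-major strides for this shape."""
    return tuple(strides) == contiguous_strides(shape)

print(contiguous_strides((2, 4, 3)))  # (12, 3, 1)
print(is_contiguous((2, 3), (3, 1)))  # True   (a)
print(is_contiguous((3, 2), (1, 3)))  # False  (a transposed)
```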
Because of this stride swap, the transposed tensor may become non-contiguous: its stride (1, 3) is not the (2, 1) that a contiguous (3, 2) tensor would have.
That's because reading the transposed tensor row by row gives [1, 4, 2, 5, 3, 6], but the memory is still [1, 2, 3, 4, 5, 6].
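This mismatch can be seen by walking each view in logical order (last index varying fastest) and comparing with the buffer. The sketch below is a toy model with a hypothetical `logical_order` helper.

```python
from itertools import product

def logical_order(buffer, shape, strides):
    """Elements in the order a row-by-row reader expects (last index varies fastest)."""
    return [buffer[sum(i * s for i, s in zip(idx, strides))]
            for idx in product(*(range(n) for n in shape))]

buffer = [1, 2, 3, 4, 5, 6]
print(logical_order(buffer, (2, 3), (3, 1)))  # [1, 2, 3, 4, 5, 6]  a: matches memory
print(logical_order(buffer, (3, 2), (1, 3)))  # [1, 4, 2, 5, 3, 6]  a transposed: does not
```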
How to solve it?
By calling .contiguous(), which creates new memory laid out as [1, 4, 2, 5, 3, 6] with stride (2, 1).
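A toy model of that copy: gather the elements in logical order into a fresh buffer, then assign row-major strides. This sketch always copies, whereas PyTorch's `contiguous()` returns the tensor itself when it is already contiguous. `make_contiguous` is a hypothetical name.

```python
from itertools import product
from math import prod

def make_contiguous(buffer, shape, strides):
    """Copy a strided view into a NEW row-major buffer and return it with fresh strides."""
    new_buffer = [buffer[sum(i * s for i, s in zip(idx, strides))]
                  for idx in product(*(range(n) for n in shape))]
    new_strides = tuple(prod(shape[i + 1:]) for i in range(len(shape)))
    return new_buffer, new_strides

old_buffer = [1, 2, 3, 4, 5, 6]                                        # storage of a
new_buffer, new_strides = make_contiguous(old_buffer, (3, 2), (1, 3))  # a transposed
print(new_buffer, new_strides)  # [1, 4, 2, 5, 3, 6] (2, 1)
print(old_buffer)               # [1, 2, 3, 4, 5, 6]  (original storage untouched)
```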
It's interesting that view() and transpose() are similar, with a subtle difference.
Neither view() nor transpose() alters the memory buffer, so how does view() reshape a tensor?
view() reshapes the tensor to the given shape by computing new strides over the same buffer, just like transpose() does. But it requires the existing layout to be compatible with the new shape, which in practice means a transposed tensor usually has to be made contiguous first. So, after transpose(), call contiguous(), and then you can use view().
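Below is a deliberately simplified `toy_view` (hypothetical) that only accepts contiguous layouts. PyTorch's real `view()` is somewhat more permissive: it succeeds whenever the requested shape is compatible with the existing strides and raises a `RuntimeError` otherwise. For the transposed ‘a’ flattened to 6 elements, both reject the request.

```python
from math import prod

def contiguous_strides(shape):
    """Row-major strides for a shape."""
    return tuple(prod(shape[i + 1:]) for i in range(len(shape)))

def toy_view(shape, strides, new_shape):
    """Reinterpret the same buffer with a new shape, without copying.
    Simplification: only allowed when the current layout is row-major contiguous."""
    if prod(new_shape) != prod(shape):
        raise ValueError("view must keep the number of elements")
    if tuple(strides) != contiguous_strides(shape):
        raise RuntimeError("toy_view: tensor is not contiguous; call contiguous() first")
    return tuple(new_shape), contiguous_strides(new_shape)

print(toy_view((2, 3), (3, 1), (6,)))    # ((6,), (1,))       a viewed as 1-D
print(toy_view((3, 2), (2, 1), (2, 3)))  # ((2, 3), (3, 1))   transposed a after contiguous()
try:
    toy_view((3, 2), (1, 3), (6,))       # transposed a, still non-contiguous
except RuntimeError as err:
    print(err)
```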

