---
title: Low-memory stochastic backpropagation with multi-channel randomized trace estimation
author: |
    Mathias Louboutin^1^, Ali Siahkoohi^1^, Rongrong Wang^2^, Felix J. Herrmann^1^\
    ^1^ School of Computational Science and Engineering, Georgia Institute of Technology\
    ^2^ Department of Mathematics, Michigan State University
bibliography:
    - probeml.bib
---

# refs

- @NEURIPS2019_9015: PyTorch
- @Flux
- @lightonproj

Trace estimation as part of Gaussian processes:

- @GPyTorch: uses trace estimation for the trace term in Gaussian inference

Gradient approximation in ML and other fields:

- @DFA1, @han2019efficient, @frenkel: DFA, only use DFA for the last couple of dense layers
- @oktay2020randomized: Princeton RAD, based on an unbiased estimate of the identity matrix, but with subsampled basis vectors. Per batch element, hence expensive and not as much memory saving.
- @wangacccnn: randomized CNN gradient (zero dw, randn dw, top-k dy), but only replaces 25% of the layers, ~2% error
- @adelman2021faster: fast NN with approximation, why the first layer should be kept exact
- @nokland19a: training with errors, per-sub-block loss
- @li2019happens: compressed sensing CNN

Low rank:

- @TroppLR: streaming algorithm for low-rank matrices

Large batch size:

- @DBLP: Facebook large mini-batch (batch times k, lr times k)
- @LARS: large-batch-size reference paper and algorithm, 32k batch size for ImageNet
- @Pinckaers: streamed tile-based convolution for large images (8k)

Probing:

- @Kaperick2019DiagonalEW: master's thesis
- @cortinovis2020randomized: trace estimation for indefinite matrices
- @hutchpp, @Avron
- @2014cudnn

# TODO

- [ ] Add refs
- [ ] Redo math
- [ ] Run GPU benchmark (maybe)
- [ ] Clean up theory
- [ ] Redo some plots

## Abstract

Thanks to the combination of state-of-the-art accelerators and highly optimized open software frameworks, there has been tremendous progress in the performance of deep neural networks. While these developments have been responsible for many breakthroughs, progress towards solving large-scale problems, such as video encoding and semantic segmentation in 3D, is hampered because access to on-premise memory is often limited. Instead of relying on (optimal) checkpointing or invertibility of the network layers – to recover the state variables during backpropagation – we propose to approximate the gradient of convolutional neural nets with a multi-channel randomized trace estimation technique. Compared to other methods, this approach is simple, easy to analyze, and leads to a greatly reduced memory footprint. Even though randomized trace estimation introduces stochasticity during training, we argue that this is of little consequence as long as the induced errors are of the same order as the errors in the gradient due to the use of stochastic gradient descent. We discuss the performance of networks trained with stochastic backpropagation and how the error can be controlled while maximizing memory savings and minimizing computational overhead.

## Introduction

- Convolution layer gradients are expensive and the main cost of CNNs.
- Unbiased approximations have been shown to work well (RAD), and so has randomness (DFA).
- Low memory = larger batch size. Large batch sizes are trending (LARS, Facebook large minibatch).
- (Lessons learned from PDE adjoint state.)
- Extremely easy swap-in, swap-out implementation compatible with any PyTorch network.

## Related work

- RAD
- Compressed sensing
- DFA
- Randomized SVD
- Gaussian inference

## Theory

We start by casting the convolution layer into a linear algebra framework and introducing the randomized trace estimation method.
We show here that the sensitivity of a convolution layer with respect to its parameters can be rewritten as the trace of an extremely large outer product. Since this outer product cannot realistically be formed, we propose an estimator for this trace based on randomized linear algebra that provides an unbiased approximation of the sensitivity. From this trace estimate, we then derive a memory-frugal algorithm that only requires saving a fraction of the layer input during forward propagation.

### Single channel

We first consider the standard single-channel convolution of an input ``\mathbf{X} \in \mathbb{R}^{N \times B}`` with a filter ``\mathbf{w} \in \mathbb{R}^{K \times K}``, where ``B`` is the mini-batch size, ``K`` is the stencil width and ``N`` is the number of pixels in each input image. We can write the convolution in linear algebra form as follows:

```math {#LAconv}
\mathbf{W} &= \sum_{i=1}^{n_w} w_i \mathbf{T}_{k(i)} \\
\mathbf{Y} &= \mathbf{W}\mathbf{X}.
```

In this formulation, ``\mathbf{T}_{k(i)}`` is the circular shift operator that shifts a vector by ``k(i)``, where ``k(i)`` is the shift associated with the ``i^{th}`` weight ``w_i`` and ``n_w = K^2`` is the number of weights in the filter. While this formulation does not reflect how the convolution is computed in practice, it lets us derive the gradient with respect to the weights in a linear algebra framework. Applying the chain rule and standard linear algebra identities [ref matrix cookbook] to Equation #LAconv we can write:

```math {#dwtr}
\frac{\partial }{\partial w_i} f(\mathbf{W}\mathbf{X}) &= \operatorname{tr}\left(\left(\frac{\partial f(\mathbf{W}\mathbf{X})}{\partial \mathbf{W}}\right)^\top \frac{\partial \mathbf{W}}{\partial w_i}\right) \\
&= \operatorname{tr}\left(\left(\delta \mathbf{Y} \mathbf{X}^\top\right)^\top \mathbf{T}_{k(i)}^\top\right) \\
&= \operatorname{tr}\left(\mathbf{X} \delta \mathbf{Y}^\top \mathbf{T}_{-k(i)}\right) \\
&= \operatorname{tr}_{-k(i)}\left(\mathbf{X} \delta \mathbf{Y}^\top\right), \quad i=1,\dots,n_w.
```

Here, ``\operatorname{tr}`` is the trace operator that sums the values on the diagonal of a matrix and ``\operatorname{tr}_{k}`` is the "off-diagonal trace" that sums the values on the ``k^{th}`` off-diagonal. While explicitly forming this outer product would be inefficient both computationally and memory-wise, we can obtain an unbiased estimate of its trace via matrix probing techniques [refs]. These methods are designed to estimate diagonals and traces of matrices that are either too big to be formed explicitly or, more generally, of linear operators that are only accessible through their action rather than their entries, as in matrix-free frameworks (e.g., SPOT, PyLops) [refs]. The unbiased estimate of the trace is obtained via repeated left and right matrix-vector products with carefully chosen random vectors. The trace estimator derives from the following unbiased estimate of the identity matrix:

```math {#trdef}
\operatorname{tr}(\mathbf{A}) &= \operatorname{tr}\left(\mathbf{A} \mathbf{I}\right) = \operatorname{tr}\left(\mathbf{A} \mathbb{E}\left[\mathbf{z} \mathbf{z}^{\top}\right]\right) \\
&= \mathbb{E}\left[\operatorname{tr}\left(\mathbf{A} \mathbf{z} \mathbf{z}^{\top}\right)\right] = \mathbb{E}\left[\mathbf{z}^{\top} \mathbf{A} \mathbf{z}\right] \\
&\approx \frac{1}{r}\sum_{i=1}^r \mathbf{z}^{\top}_i \mathbf{A} \mathbf{z}_i = \frac{1}{r}\operatorname{tr}\left(\mathbf{Z}^\top \mathbf{A}\mathbf{Z}\right).
```
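As a quick sanity check of Equation #trdef\, the estimator can be reproduced in a few lines of NumPy; this is only an illustration, on an arbitrary synthetic matrix chosen so that the plain trace and one off-diagonal trace have easily recognizable values.

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 256, 512                         # matrix size and number of probing vectors

# Synthetic matrix: 2 on the diagonal, 3 on the 3rd (circular) off-diagonal, plus small noise,
# so the plain trace is ~2n and the probed off-diagonal trace is ~3n.
A = 2.0 * np.eye(n) + 3.0 * np.roll(np.eye(n), 3, axis=0) + 0.1 * rng.standard_normal((n, n))
Z = rng.standard_normal((n, r))         # Gaussian probing vectors, E[z z^T] = I

# Plain trace: tr(A) ~ (1/r) tr(Z^T A Z)
print("trace    exact vs probed:", np.trace(A), np.trace(Z.T @ A @ Z) / r)

# Off-diagonal trace: tr_{-k}(A) = tr(A T_{-k}), and T_{-k} Z is just a circular row shift of Z.
k = 3
T_mk = np.roll(np.eye(n), -k, axis=0)   # explicit T_{-k}, only used for the exact value
print("off-diag exact vs probed:", np.trace(A @ T_mk), np.trace(Z.T @ A @ np.roll(Z, -k, axis=0)) / r)
```

Equation #grad_pr below applies exactly this identity to the outer product ``\mathbf{X}\delta\mathbf{Y}^\top`` of Equation #dwtr\.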
We inject Equation #dwtr into this estimator to obtain the following unbiased estimator of the gradient with respect to the weights:

```math {#grad_pr}
\operatorname{tr}_{-k(i)}(\mathbf{A}) &= \operatorname{tr}_{-k(i)}\left(\mathbf{A} \mathbb{E}\left[\mathbf{z} \mathbf{z}^{\top}\right]\right) \approx \frac{1}{r}\sum_{j=1}^r \mathbf{z}^{\top}_j \mathbf{A} \mathbf{T}_{-k(i)}\mathbf{z}_j \\
\delta w_i &= \frac{1}{r} \sum_{j=1}^r \mathbf{z}_j^\top\mathbf{X} \delta \mathbf{Y}^\top\mathbf{T}_{-k(i)}\mathbf{z}_j = \frac{1}{r} \operatorname{tr}(\underbrace{(\mathbf{Z}^\top\mathbf{X})}_{\mathbf{\overline{X}}\in \mathbb{R}^{r\times B}}\delta \mathbf{Y}^\top\mathbf{T}_{-k(i)}\mathbf{Z}).
```

This estimator converges to the true gradient as ``r \rightarrow \infty``, and for a certain class of matrices (symmetric positive semi-definite, SPSD) theoretical error bounds can be derived as a function of the number of probing vectors ``r``. In our case, we do not have any such knowledge of ``\mathbf{A}=\mathbf{X} \delta\mathbf{Y}^\top``, but we show in the appendix that we can relax the symmetry hypothesis using the symmetry of the probing algorithm. Additionally, empirical results from the literature show that while the SPSD assumption is required for the convergence proof, it is not necessary in practice for convergence. Before deriving our trace-estimation-based algorithm, which takes advantage of the highlighted ``\overline{\mathbf{X}}`` in Equation #grad_pr to drastically reduce the memory requirements of a convolution layer, we extend the single-channel trace estimator to the multi-channel case, where instead of the trace of ``\mathbf{A}``, the traces of sub-blocks of this matrix are estimated.

### Multi-channel

In general, a convolution layer involves multiple input and output channels. In this case, the weight update consists of the single-channel update for all possible pairs of input and output channels. The simplest way to use the trace estimator would be to treat each of these pairs independently. However, the gains would be limited both computationally and memory-wise, as we explain in a later section. To further improve the advantage of our method, we propose a simultaneous estimator that treats all the channels at once. To do so, we consider the full outer product ``\mathbf{X} \delta\mathbf{Y}^\top``, where ``\mathbf{X}`` and ``\delta\mathbf{Y}`` are now matrices of size ``(N C_i) \times B`` and ``(N C_o) \times B``, with ``C_i, C_o`` the number of input and output channels. We then use our randomized trace estimation to estimate the trace of each ``N \times N`` sub-block, which corresponds to a single input-output channel pair. First, we probe the full outer product on the left with a large ``r \times (N C_i)`` probing matrix to obtain the probed matrix. We then reshape that matrix into a 3D tensor of size ``N \times C_o \times B`` and apply the right probing, estimating each trace from the corresponding block of the probing matrix. We illustrate the procedure of multi-channel trace estimation in Figure #multich\.

#### Figure: {#multich}
![probing size=4](figures/multi_chan.png)
: Multi-channel trace estimation. This figure shows each step of the algorithm used to estimate the trace of a sub-block of the outer product.

Because we probe the full outer product rather than each individual channel pair, we now introduce cross-channel cross-talk that reduces the accuracy of the estimator.
To overcome this setback, we can either increase the number of probing vectors or redesign the choice of ``\mathbf{z}`` to reduce the cross-talk. Recalling Equation #trdef\, we design ``\mathbf{z}`` such that ``\mathbb{E}\left[\mathbf{z} \mathbf{z}^{\top}\right] = \mathbf{I}`` still holds while enforcing a block-diagonal structure on ``\mathbf{z} \mathbf{z}^{\top}`` to reduce the cross-channel cross-talk. We show in Figure #zortho such a block-orthogonalized probing matrix ``\mathbf{Z}`` (each column is a probing vector) together with its outer product ``\mathbf{Z} \mathbf{Z}^\top``, showing that we can draw random vectors satisfying these properties. We sketch in the appendix the theoretical advantages of this choice for multi-channel trace estimation. We show convergence results on a synthetic multi-channel matrix in Figure #convmulti, which illustrate the better convergence of our block-orthogonalized probing matrix over standard purely random probing due to the reduced cross-channel cross-talk.

#### Figure: {#zortho}
![](figures/zortho.png){width=50%}
: Probing matrices and corresponding approximations of the identity for ``C_i=16, C_o=16, r=64, N=256``. We see on the left panel that the orthogonalization step leads to a near block-diagonal approximation of the identity that does not create cross-talk between different channels.

#### Figure: {#convmulti}
![probing size=4](figures/bias/var_grad1_CIFAR10-randX.png){width=50%}
: Trace estimation convergence on a synthetic multi-channel matrix with ``C_i=32, C_o=32, N=1024, B=50``.

Finally, we consider a more practical case and compute the first gradient of a complete classification network (see the CIFAR10 training section) on a random input for the CIFAR10 dataset. We compare the true gradient to three different estimators. First, we estimate the trace for each input-output channel pair independently. This first case is by design the most accurate, while leading to a much smaller memory reduction. Second, we compute the multi-channel estimate with a standard random probing matrix, which incurs cross-talk while drastically reducing the memory usage. Finally, we compute the estimate with our block-orthogonalized probing matrix, which reduces the cross-channel cross-talk. We show the first gradient in Figure #cifarfirst\. This first gradient is computed for a large probing size (2048) so that the estimate is visually comparable to the true gradient. In practice, we would use a much smaller probing size for most of the training, relying on the unbiasedness of the estimate across iterations.

#### Figure: {#cifarfirst}
![probing size=2048](figures/c4ifarfirst_2048.png){width=100%}
: Trace-estimated gradient for our convolutional network on the CIFAR10 dataset. While some inaccuracies remain, the orthogonalization step leads to a more accurate estimate that handles different scales better.

We can see that, as expected, estimating the trace per channel pair leads to a more accurate estimate. Second, we can clearly see that our block orthogonalization improves the estimate compared to a standard random probing matrix, in particular when amplitudes differ between channel pairs. The cross-talk reduction allows low-amplitude channels to be estimated better by reducing the influence of high-amplitude channels. With our estimator defined for practical multi-channel cases, we now highlight the main advantage of our method: the massive memory reduction obtained by computing the gradient with respect to the weights with our unbiased trace estimator.
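Before moving on to that implementation, the cross-talk effect can be reproduced with a minimal NumPy sketch (not the paper's code): it compares per-channel-pair probing with joint probing of the full outer product on a synthetic ``\mathbf{X}\delta\mathbf{Y}^\top`` whose channels have very different amplitudes. The sizes, the amplitudes, and the restriction to diagonal (``i=j``) blocks are arbitrary choices for illustration, and the block-orthogonalized probing matrix itself is not reproduced here since its construction is deferred to the appendix.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, B, r = 64, 4, 16, 256            # pixels per channel, channels (C_i = C_o), batch, probes

# Synthetic multi-channel outer product X dY^T with very different channel amplitudes,
# so that cross-channel cross-talk is visible on the weak channel.
amp = np.array([1.0, 10.0, 10.0, 10.0])
X = np.concatenate([a * rng.standard_normal((N, B)) for a in amp])        # (C*N) x B
dY = X + np.concatenate([a * rng.standard_normal((N, B)) for a in amp])   # correlated residual

def block(M, c):                        # rows belonging to channel c
    return M[c * N:(c + 1) * N]

true_tr = np.array([np.trace(block(X, c) @ block(dY, c).T) for c in range(C)])

# (a) Per channel-pair probing: a fresh N x r probing matrix for every block.
per_pair = np.empty(C)
for c in range(C):
    Zc = rng.standard_normal((N, r))
    per_pair[c] = np.trace(Zc.T @ block(X, c) @ block(dY, c).T @ Zc) / r

# (b) Joint probing of the full outer product: one (C*N) x r probing matrix.
#     Only Xbar = Z^T X (r x B) has to be kept, but the other channels leak into each estimate.
Z = rng.standard_normal((C * N, r))
Xbar = Z.T @ X
joint = np.array([np.trace(Xbar @ block(dY, c).T @ block(Z, c)) / r for c in range(C)])

print("relative error, per-pair:", np.abs(per_pair - true_tr) / np.abs(true_tr))
print("relative error, joint   :", np.abs(joint - true_tr) / np.abs(true_tr))
```

With this setup, the joint estimate of the low-amplitude channel typically suffers most from the leakage of the stronger channels, which is exactly the effect the block-orthogonalized probing matrix is designed to mitigate; the precise numbers vary with the random seed.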
## Compact memory forward-backward implementation

In order to reduce the memory footprint of a convolutional layer, we implemented the proposed method with a compact-in-memory forward-backward design. This implementation is based on the symmetry of the probing and of the trace. We can reformulate the trace expression in Equation #dwtr and its unbiased estimate in Equation #grad_pr as:

```math {#dwsplit}
\delta w_i = \frac{1}{r} \operatorname{tr}\left(\left(\mathbf{Z}^\top\mathbf{X}\right)\delta \mathbf{Y}^\top\mathbf{T}_{-k(i)}\mathbf{Z}\right), \quad i=1,\dots,n_w.
```

In this symmetrized expression, the shifts are applied to the backpropagated residual, which allows us to compute ``\overline{\mathbf{X}} = \mathbf{Z}^\top \mathbf{X}`` during the forward propagation through the layer. This precomputation then only requires storing the matrix ``\overline{\mathbf{X}}`` of size ``r \times B``, which leads to a memory reduction by a factor of ``\frac{N C_i}{r}``. The computation of this expression can be summarized in three steps that minimize recomputation while requiring a single temporary of size at most equal to the input:

1. ``\overline{\mathbf{X}} = \mathbf{Z}^\top \mathbf{X}``
2. ``\mathbf{L} = \overline{\mathbf{X}} \delta\mathbf{Y}^\top``
3. For each ``i``: ``\delta w_i = \frac{1}{r}\operatorname{tr}(\mathbf{L} \mathbf{T}_{-k(i)}\mathbf{Z})``

For a small image of size ``32 \times 32`` with ``16`` input channels, this implementation leads to a memory reduction by a factor of ``2^{14-p}`` for ``r=2^p`` (more than ×100 for ``r=128``). We then only need to allocate, for each layer, temporary memory for the probing matrix, which can be redrawn from a saved random generator seed. The forward-backward algorithm is summarized in Algorithm #ev_fwd_bck\.

### Algorithm: {#ev_fwd_bck}
| Forward pass:
| 1. Convolution ``\mathbf{Y} = C(\mathbf{X}, \mathbf{w})``
| 2. Draw a random seed ``s`` and a probing matrix ``\mathbf{Z}(s)``
| 3. Compute ``\overline{\mathbf{X}} = \mathbf{Z}(s)^\top\mathbf{X} \in \mathbb{R}^{r \times B}``
| 4. Store ``\overline{\mathbf{X}}, s``
| Backward pass:
| 1. Load the random seed ``s`` and the probed forward ``\overline{\mathbf{X}}``
| 2. Redraw the probing matrix ``\mathbf{Z}(s)`` from ``s``
| 3. Compute the backward probe ``\mathbf{L} = \overline{\mathbf{X}} \delta\mathbf{Y}^\top``
| 4. Compute the gradient ``\delta w_i = \frac{1}{r} \operatorname{tr}(\mathbf{L} \mathbf{T}_{-k(i)}\mathbf{Z}(s))``

: Forward-backward unbiased gradient estimator via trace estimation.

This simple yet powerful algorithm provides a virtually memory-free estimate of the true gradient with respect to the weights.

# Validation

- **Accuracy**. We look at the accuracy of the estimated gradient against the true gradient for varying batch sizes, image sizes and numbers of probing vectors.
- **Bias**. We verify that the gradient estimate is unbiased on the CIFAR10 dataset by comparing the expectation of our gradient approximation against the true gradient.

## Sample variance

We compute the gradient with respect to the filter of the standard image-to-image mean-square error ``\frac{1}{2}||C(X) - Y||^2``, where ``C`` is a pure convolution layer ([NNlib.jl](https://github.com/FluxML/NNlib.jl)) and ``Y`` is a batch of images from the CIFAR10 dataset. We consider two cases for ``X``. In the first case, ``X`` is a batch drawn from the CIFAR10 dataset as well, while in the second case, ``X`` is drawn from ``\mathcal{N}(0, 1)``.
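Before turning to the figures, the comparison underlying this experiment can be sketched in a few lines of NumPy for a single-channel, 1D circular convolution (the actual experiment uses NNlib.jl and 2D convolutions on CIFAR10 images); the sizes, the width-3 stencil, and the Gaussian data are simplifications, and the sign of the shift depends on the convention chosen for ``\mathbf{T}_k``.

```python
import numpy as np

rng = np.random.default_rng(0)
N, B, r = 1024, 50, 64                 # pixels, batch size, number of probing vectors
shifts = [-1, 0, 1]                    # shifts k(i) of a width-3 circular stencil
w = rng.standard_normal(len(shifts))   # filter weights w_i

X = rng.standard_normal((N, B))        # layer input, one column per batch element
Y = rng.standard_normal((N, B))        # target batch

# Forward pass: W X with W = sum_i w_i T_{k(i)}, where T_k is a circular shift by k.
WX = sum(wi * np.roll(X, k, axis=0) for wi, k in zip(w, shifts))
dY = WX - Y                            # residual of the 1/2 ||WX - Y||^2 loss

# Exact gradient: dw_i = <dY, T_{k(i)} X>, i.e. the off-diagonal traces of the outer product.
dw_exact = np.array([np.sum(dY * np.roll(X, k, axis=0)) for k in shifts])

# Probed gradient (cf. Equation #dwsplit): only Xbar = Z^T X (r x B) is kept from the forward pass.
Z = rng.standard_normal((N, r))
Xbar = Z.T @ X                         # r x B, i.e. N/r times smaller than X
L = Xbar @ dY.T                        # r x N, formed during the backward pass
dw_est = np.array([np.trace(L @ np.roll(Z, k, axis=0)) / r for k in shifts])

print("exact :", dw_exact)
print("probed:", dw_est)
```

At this probing size the single-draw estimate is still visibly noisy; the figures below quantify exactly this spread and how it shrinks with the number of probing vectors and the batch size.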
#### Figure: {#bias-cifar10-rand}
![probing size=4](figures/bias/var_grad1_CIFAR10-randX.png){width=50%} ![probing size=8](figures/bias/var_grad2_CIFAR10-randX.png){width=50%}\
![probing size=16](figures/bias/var_grad3_CIFAR10-randX.png){width=50%} ![probing size=32](figures/bias/var_grad4_CIFAR10-randX.png){width=50%}
: Gradients with ``X`` drawn from the normal distribution.

#### Figure: {#bias-cifar10}
![probing size=4](figures/bias/var_grad1_CIFAR10.png){width=50%} ![probing size=8](figures/bias/var_grad2_CIFAR10.png){width=50%}\
![probing size=16](figures/bias/var_grad3_CIFAR10.png){width=50%} ![probing size=32](figures/bias/var_grad4_CIFAR10.png){width=50%}
: Gradients with ``X`` drawn from the CIFAR10 dataset.

We show these gradients in Figures #bias-cifar10 and #bias-cifar10-rand\. These figures demonstrate three points. First, we can see that an increasing number of probing vectors leads to a better estimate of ``\delta W[i, j]`` and a reduced standard deviation, making a single sample a more accurate estimate, in line with theoretical expectations from the literature. Second, we show that our estimate is unbiased, as both the mean and the median match the true gradient. Finally, these figures show that an increased batch size leads to a more accurate estimate and a reduced variance, allowing a smaller number of probing vectors, and therefore better performance, for larger batch sizes.

## Standard deviation

We now consider the variance across different batches within a dataset. It is known that large batch sizes lead to sharper minima, and therefore to less variance across different batches, as shown in Figure #std-cifar10\. We show here that, as expected, our estimate follows this trend as well and that, due to the estimator error, the variance is increased. However, since we can afford a larger batch size for similar memory usage, we can obtain a lower variance for a given memory budget by using a larger batch size.

#### Figure: {#std-cifar10}
![var](figures/bias/var_conv_40.png){width=100%}
: Standard deviation of the gradients w.r.t. the weights for each convolution layer of our convolutional network. The standard deviation is computed over 40 mini-batches drawn from the full CIFAR10 dataset.

# Performance

We now look at the pure computational performance of computing a single gradient. We compare the runtime of a single gradient for our Julia implementation on CPU against NNlib, and for our PyTorch implementation on GPU against the standard convolution layer. We show that, despite a fairly simplistic implementation, we already see computational advantages for larger problems (batch size, image size, number of channels). These results show that a native CUDA implementation could outperform existing state-of-the-art convolution layer implementations, making our algorithm advantageous both computationally and memory-wise.

## Runtime

We show in Figures #cpu-bench and #gpu-bench the benchmarked runtimes for computing a single gradient with NNlib and with our method for varying image sizes and batch sizes. The benchmark was run for a small (4 => 4) and a large (32 => 32) number of channels.
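To give a feel for where the computational trade-off comes from, here is a rough PyTorch timing sketch. It is not the benchmark script behind the figures: the sizes are arbitrary, it runs on CPU, and it only times the standard convolution forward/backward against the two dense contractions of steps 1 and 2 of the probed backward pass (the final per-shift trace step is omitted).

```python
import time
import torch

def timeit(fn, repeat=10):
    fn()                                   # warm-up
    t0 = time.perf_counter()
    for _ in range(repeat):
        fn()
    return (time.perf_counter() - t0) / repeat   # on GPU, add torch.cuda.synchronize()

B, C, N, r = 64, 16, 64, 64                # batch, channels, image size, probing vectors

x = torch.randn(B, C, N, N, requires_grad=True)
w = torch.randn(C, C, 3, 3, requires_grad=True)

def conv_grad():                           # standard conv2d forward + full backward
    y = torch.nn.functional.conv2d(x, w, padding=1)
    y.sum().backward()
    x.grad = w.grad = None

Xf  = torch.randn(C * N * N, B)            # flattened layer input
dYf = torch.randn(C * N * N, B)            # flattened backpropagated residual
Z   = torch.randn(C * N * N, r)

def probe_contractions():                  # Xbar = Z^T X (forward) and L = Xbar dY^T (backward)
    Xbar = Z.t() @ Xf                      # r x B, the only tensor the method stores
    return Xbar @ dYf.t()                  # r x (C*N*N)

print("conv2d forward+backward:", timeit(conv_grad), "s")
print("probe contractions     :", timeit(probe_contractions), "s")
```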
#### Figure: {#cpu-bench}
![B=4](figures/runtimes/bench_cpu_4_4.png){width=18%} ![B=8](figures/runtimes/bench_cpu_4_8.png){width=18%} ![B=16](figures/runtimes/bench_cpu_4_16.png){width=18%} ![B=32](figures/runtimes/bench_cpu_4_32.png){width=18%} ![B=64](figures/runtimes/bench_cpu_4_64.png){width=18%}\
![B=4](figures/runtimes/bench_cpu_32_4.png){width=18%} ![B=8](figures/runtimes/bench_cpu_32_8.png){width=18%} ![B=16](figures/runtimes/bench_cpu_32_16.png){width=18%} ![B=32](figures/runtimes/bench_cpu_32_32.png){width=18%} ![B=64](figures/runtimes/bench_cpu_32_64.png){width=18%}\
: CPU benchmark on an *Intel(R) Xeon(R) CPU E3-1270 v6 @ 3.80GHz* node. The first row contains the runtimes for 4 channels and the second row for 32 channels. We can see that for large images and batch sizes, our implementation provides a significant performance gain.

#### Figure: {#gpu-bench}
![B=32](figures/runtimes/gpu_perf_32.png){width=50%} ![B=64](figures/runtimes/gpu_perf_64.png){width=50%}\
![B=128](figures/runtimes/gpu_perf_128.png){width=50%} ![B=256](figures/runtimes/gpu_perf_256.png){width=50%}
: GPU benchmark on a *Tesla K80* (Azure NC6 instance) for a single gradient for varying batch sizes, image sizes and numbers of channels (same number of input and output channels). We observe that for larger-scale problems we perform as well as, if not better than, state-of-the-art cuDNN kernels.

These performance results on CPU are very promising and compete with existing high-performance implementations. We even outperform a standard *im2col* implementation by a factor of up to ×10 for large images. On GPU, we do not see any performance advantage for small problems, where the cuDNN implementation is as efficient as can be. However, for larger problems we start to be competitive, nearing similar if not better performance than highly optimized, hand-coded CUDA kernels.

## Network memory

Following up on the memory gains we introduced and demonstrated for convolutional layers in section #ref, we now consider full neural networks. In general terms, the memory gain is driven by the proportion of convolution layers in the network. Because our implementation is virtually memory-free (``\mathcal{O}(\mathrm{MB})`` memory cost), we can estimate the total memory gain to be equivalent to that proportion. We demonstrate that this estimate holds in practice on several mainstream networks in Figure #nn-mem\.

#### Figure: {#nn-mem}
![](figures/memory/squeezenet1_1.png){width=50%} ![](figures/memory/squeezenet1_0.png){width=50%}\
![](figures/memory/resnet18.png){width=50%} ![](figures/memory/resnet50.png){width=50%}
: Network memory usage for a single gradient. We show the memory usage of well-known networks for low and high probe sizes at a fixed input size.

We can see that the memory usage of a network is effectively independent of the number of probing vectors due to the massive memory gain incurred by our estimator. In some cases, the memory usage is a bit higher than half that of the unmodified network due to in-place `ReLU` layers. In the default configuration, such a layer is memory-free, relying on the following convolution layer to store the values needed for backpropagation. Because our estimator no longer stores these values, the memory usage of the in-place layer increases. We show that this can easily be offset by storing only the necessary information as single bits (`int8` in PyTorch, which does not support bit arrays).
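To make that remark concrete, here is a minimal PyTorch sketch (not the pyxconv implementation) of a ReLU that saves only a boolean sign mask for the backward pass instead of relying on the stored convolution input; PyTorch stores boolean tensors with one byte per element, which is the `int8`-style storage referred to above.

```python
import torch

class MaskReLU(torch.autograd.Function):
    """ReLU that keeps only a 1-byte-per-element sign mask for the backward pass."""

    @staticmethod
    def forward(ctx, x):
        mask = x > 0                      # torch.bool tensor, one byte per element
        ctx.save_for_backward(mask)
        return x * mask                   # zero out the negative part

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask         # pass gradients only where the input was positive


x = torch.randn(8, 16, 32, 32, requires_grad=True)
y = MaskReLU.apply(x)
y.sum().backward()
print(torch.equal(x.grad, (x > 0).float()))   # matches the gradient of a standard ReLU
```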
In summary, the network memory usage is limited by the other layers; nevertheless, our estimator uses in general about half the memory of the standard network (for convolutional networks), allowing us to work with double the batch size for a fixed memory budget.

# Training

This last section verifies that our unbiased estimator can be used to train convolutional networks and leads to good accuracy. We demonstrate training on the MNIST and CIFAR10 datasets and show that, for large batch sizes, our estimator provides accuracy comparable to conventional training.

### MNIST

We show here two different trainings, one with each of our Julia and PyTorch implementations. In both cases, we demonstrate that our low-memory estimator achieves accuracy comparable to the same network trained with the true convolution gradient. We also show that our estimator outperforms standard convolution in some cases when using a line-search-based algorithm.

#### XConv.jl

We first train on the MNIST dataset with varying batch sizes and numbers of probing vectors.

### Table: {#MNIST-batch}
|            | ``B=64``   | ``B=128``  | ``B=256``  |
|:----------:|:----------:|:----------:|:----------:|
| default    | ``0.9905`` | ``0.9898`` | ``0.9901`` |
| ``ps=2``   | ``0.9625`` | ``0.9692`` | ``0.9745`` |
| ``ps=16``  | ``0.9753`` | ``0.9803`` | ``0.9823`` |
| ``ps=64``  | ``0.9777`` | ``0.9723`` | ``0.9782`` |
| ``ps=256`` | ``0.9718`` | ``0.9706`` | ``0.9791`` |
: Training accuracy for varying batch sizes ``B`` and numbers of probing vectors ``ps`` on the MNIST dataset.

We observe in these results the same behavior we saw when looking at the accuracy of a single gradient. The accuracy decreases for a low number of probing vectors, which do not provide an accurate enough estimate despite being unbiased. As we increase the number of probing vectors and the batch size, the training accuracy increases as expected. Globally, we lose less than 1% in accuracy while being memory-frugal and computationally efficient.

#### pyxconv

Second, we train a convolutional network (see appendix, Table #MNISTpy) on the same MNIST dataset using our PyTorch implementation. We again train for 20 epochs. In this experiment, we used a recent line-search-based optimization method (SLS, [@vaswani2019painless]) to train our network. This method relies on a stochastic line search and the calculation of an extra gradient to improve convergence. By using this method, we show that our estimate is good enough for a line-search-based method to converge. This means our approximate gradient is close enough to the true gradient direction for the optimization algorithm to progress in the right direction while satisfying the functional decrease condition. We trained our network for a varying number of probing vectors and show the accuracy convergence in Figure #mnist-sls\. We can see that we perform as well as training with the true gradient while using drastically less memory (about ×100 for ``r=16``) for all batch sizes, and that we outperform standard training in some cases. Since the differences between the cases are so small, we can say that our method is on par with standard training, the variations being due to the stochastic nature of the training.

#### Figure: {#mnist-sls}
![B=4](figures/Acc_mnist.png)
: MNIST training for varying batch sizes and probing sizes. This experiment is run with the Stochastic Line Search algorithm (SLS, [@vaswani2019painless]).

### CIFAR10

We finally look at a more realistic classification problem and compare the training accuracy of all the methods introduced in this work.
We compare here the reference solution (true gradient) with a batch size of 128 to three different cases using our estimator and a batch size of 256 (×2), corresponding to the expected memory gain. First, we trained a network with our trace estimator applied to each channel pair individually. This most accurate case uses 32 probing vectors per pair, leading to a memory reduction by a factor of `32` for the first two convolution layers and `8` for the last two layers. Second, we used our multi-channel estimator with a standard random probing matrix and `256` probing vectors, leading to a memory reduction by a factor of `64` for the first two convolution layers and `16` for the last two. Finally, we trained with our orthogonalized multi-channel probing matrix, also with `256` probing vectors, leading to the same memory reduction as the previous case. We trained for 100 epochs and show the test accuracy and training loss for these four cases in Figure #cifar-sgd\. The training hyperparameters and network architecture are described in the appendix.

#### Figure: {#cifar-sgd}
![](figures/cifar10.png)
: CIFAR10 training with equivalent memory, comparing our proposed method with standard training. The four panels show the training and test loss at the bottom and the training and test accuracy (top-1) at the top after 100 epochs.

# Implementation and code availability

Our probing algorithm is implemented both in Julia, using BLAS on CPU and cuBLAS on GPU for the linear algebra computations, and in PyTorch. The code is available on GitHub. The Julia interface is designed so that preexisting networks can be reused as is, overloading `rrule` (see [ChainRulesCore](https://github.com/JuliaDiff/ChainRulesCore.jl)) to switch easily between our estimator and the conventional true gradient. The PyTorch implementation defines a new layer that can be swapped for the conventional `torch.nn.Conv2d` in any network. The code and examples are available at [XConv](https://github.com/slimgroup/XConv.jl).

# Conclusions

- Good performance for large images and/or batch sizes.
- Few probing vectors are needed if the batch size is large enough.
- A fairly suboptimal implementation leads to less impressive results on GPU, but this can be improved.

# References

# Annexes

## Trace estimation theory

## Networks

We describe here the network architectures used in our experiments. The MNIST architectures are standard architectures from the literature, and the CIFAR architecture is an intentionally convolution-heavy architecture obtained from @oktay2020randomized.
### Table: {#MNISTjl}

Layer | kernel size | Input size (``B \times C \times N_x \times N_y``) | Output size (``B \times C \times N_x \times N_y``)
:---------:|:-----------:|:----------:|:-----------:
Conv2d | (3, 3) | Bx1x28x28 | Bx16x28x28 |
relu | - | Bx16x28x28 | Bx16x28x28 |
Maxpool | (2, 2) | Bx16x28x28 | Bx16x14x14 |
Conv2d | (3, 3) | Bx16x14x14 | Bx32x14x14 |
relu | - | Bx32x14x14 | Bx32x14x14 |
Maxpool | (2, 2) | Bx32x14x14 | Bx32x7x7 |
Conv2d | (3, 3) | Bx32x7x7 | Bx32x7x7 |
relu | - | Bx32x7x7 | Bx32x7x7 |
Maxpool | (2, 2) | Bx32x7x7 | Bx32x3x3 |
flatten | - | Bx32x3x3 | Bx288 |
Dense | - | Bx288 | Bx10 |
: MNIST network and sizes for training with XConv.jl for a batch size ``B``.

### Table: {#MNISTpy}

Layer | kernel size | Input size (``B \times C \times N_x \times N_y``) | Output size (``B \times C \times N_x \times N_y``)
:---------:|:-----------:|:----------:|:-----------:
Conv2d | (3, 3) | Bx1x28x28 | Bx32x28x28 |
relu | - | Bx32x28x28 | Bx32x28x28 |
Conv2d | (3, 3) | Bx32x28x28 | Bx64x28x28 |
relu | - | Bx64x28x28 | Bx64x28x28 |
Maxpool | (2, 2) | Bx64x28x28 | Bx64x14x14 |
Dropout | - | Bx64x14x14 | Bx64x14x14 |
flatten | - | Bx64x14x14 | Bx12544 |
Dense | - | Bx12544 | Bx128 |
relu | - | Bx128 | Bx128 |
Dropout | - | Bx128 | Bx128 |
Dense | - | Bx128 | Bx10 |
Log Softmax | - | Bx10 | Bx10 |
: MNIST network and sizes for training with pyxconv for a batch size ``B``.

### Table: {#CIFARpy}

Layer | kernel size | Input size (``B \times C \times N_x \times N_y``) | Output size (``B \times C \times N_x \times N_y``)
:---------:|:-----------:|:----------:|:-----------:
Conv2d | (5, 5) | Bx3x32x32 | Bx16x32x32 |
relu | - | Bx16x32x32 | Bx16x32x32 |
Conv2d | (5, 5) | Bx16x32x32 | Bx32x32x32 |
relu | - | Bx32x32x32 | Bx32x32x32 |
AvgPool | (2, 2) | Bx32x32x32 | Bx32x16x16 |
Conv2d | (5, 5) | Bx32x16x16 | Bx32x16x16 |
relu | - | Bx32x16x16 | Bx32x16x16 |
Conv2d | (5, 5) | Bx32x16x16 | Bx32x16x16 |
relu | - | Bx32x16x16 | Bx32x16x16 |
AvgPool | (2, 2) | Bx32x16x16 | Bx32x8x8 |
flatten | - | Bx32x8x8 | Bx2048 |
Dense | - | Bx2048 | Bx10 |
Log Softmax | - | Bx10 | Bx10 |
: CIFAR10 network and sizes for training with pyxconv on the CIFAR10 dataset for a batch size ``B``.

## Training parameters

### MNIST with XConv.jl

- Tesla K80
- Network: the standard convolutional network of Table #MNISTjl: Conv(1=>16) -> MaxPool -> Conv(16=>32) -> MaxPool -> Conv(32=>32) -> MaxPool -> flatten -> dense
- 20 epochs
- ADAM with an initial learning rate of ``.003``
- MNIST dataset for varying batch sizes and probing sizes
- The first layer is always kept intact since its input is already in memory for free.
- Julia implementation

### MNIST with pyxconv

- Tesla K80
- Network: the standard convolutional network of Table #MNISTpy
- 20 epochs
- SLS with an initial learning rate of 1 (SLS default)
- MNIST dataset for varying batch sizes and probing sizes
- The first layer is always kept intact since its input is already in memory for free.
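For convenience, here is a minimal PyTorch sketch of the network in Table #MNISTpy. The `padding=1` setting is an assumption made so that the 3x3 convolutions preserve the 28x28 spatial size listed in the table, and the dropout probabilities are placeholders since they are not specified here.

```python
import torch.nn as nn

# Network of Table #MNISTpy; padding and dropout probabilities are assumptions.
mnist_pyxconv = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding=1),    # B x 1 x 28 x 28 -> B x 32 x 28 x 28
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),   # -> B x 64 x 28 x 28
    nn.ReLU(),
    nn.MaxPool2d(2),                               # -> B x 64 x 14 x 14
    nn.Dropout(0.25),
    nn.Flatten(),                                  # -> B x 12544
    nn.Linear(64 * 14 * 14, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1),
)
```

In the probed training runs, every `Conv2d` except the first (kept exact, as noted above) would be swapped for the drop-in replacement layer described in the implementation and code availability section.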