DARTS: Differentiable Architecture Search (ICLR 2019)
Abstract
Formulates architecture search in a differentiable manner
Based on a continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent
Orders of magnitude faster than state-of-the-art non-differentiable techniques
Implementation is publicly available
Introduction
Existing architecture search algorithms are computationally demanding:
- 2000 GPU days of RL
- 3150 GPU days of evolution
Several approaches for speeding up have been proposed:
- imposing a particular structure of the search space
- weights or performance prediction for each individual architecture
- weight sharing/inheritance across multiple architectures
But the fundamental challenge of scalability remains: architecture search is still treated as a black-box optimization problem over a discrete domain, which requires a large number of architecture evaluations
Instead of searching over a discrete set of candidate architectures, we relax the search space to be continuous and optimize the architecture with gradient descent
Yet the method is generic enough to handle both convolutional and recurrent architectures
While prior works seek to fine-tune a specific aspect of an architecture, such as filter shapes or branching patterns in a convolutional network, DARTS is able to learn high-performance architecture building blocks with complex graph topologies within a rich search space
Contributions
- a novel algorithm for differentiable network architecture search based on bilevel optimization, applicable to both convolutional and recurrent architectures
- highly competitive results on CIFAR-10 and PTB
- remarkable efficiency improvement
- architectures learned by DARTS on CIFAR-10 and PTB are transferable to ImageNet and WikiText-2
Search space
Search for a computation cell as the building block of the final architecture
The learned cell could either be stacked to form a convolutional network or recursively connected to form a recurrent network
A cell is a directed acyclic graph consisting of an ordered sequence of N nodes. Each node $x^{(i)}$ is a latent representation. Each directed edge $(i,j)$ is associated with an operation $o^{(i,j)}$ that transforms $x^{(i)}$
Each cell has two input nodes and a single output node
$$
x^{(j)}=\sum_{i<j}o^{(i,j)}(x^{(i)})
$$
The task of learning the cell therefore reduces to learning the operation on its edges
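A minimal sketch of this node computation, assuming each edge already has a fixed operation and that the cell output concatenates the intermediate nodes (names are illustrative, not the paper's code):

```python
import torch

def cell_forward(inputs, ops, num_intermediate):
    """Compute a cell as a DAG: x^(j) = sum_{i<j} o^(i,j)(x^(i)).

    inputs:            list holding the two input-node tensors
    ops:               dict mapping edge (i, j) -> callable o^(i,j)
    num_intermediate:  number of intermediate nodes to compute
    """
    states = list(inputs)
    for j in range(len(inputs), len(inputs) + num_intermediate):
        states.append(sum(ops[(i, j)](states[i]) for i in range(j)))
    # Output node: concatenate the intermediate nodes (as in DARTS conv cells).
    return torch.cat(states[len(inputs):], dim=1)
```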
Continuous Relaxation and Optimization
Summary:
A continuous relaxation scheme for the search space, leading to a differentiable learning objective for the joint optimization of the architecture and its weights
Let $\mathcal{O}$ be a set of candidate operations (e.g., convolution, max pooling, zero)
Relax the categorical choice of a particular operation to a softmax over all possible operations
$$
\overline{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x)
$$
The vector $\alpha^{(i,j)}$ contains the operation mixing weights for the edge between the pair of nodes $(i,j)$
Architecture search thus reduces to learning a set of continuous variables $\alpha = \{\alpha^{(i,j)}\}$; at the end of search, each mixed operation is replaced by the most likely operation:
$$
o^{(i, j)}=\operatorname{argmax}_{o \in \mathcal{O}} \alpha_{o}^{(i, j)}
$$
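A minimal sketch of the mixed operation on one edge, assuming `candidate_ops` is a list of `nn.Module`s implementing the operations in $\mathcal{O}$ (illustrative, not the official implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """Softmax-weighted sum of all candidate operations on one edge (i, j)."""

    def __init__(self, candidate_ops):
        super().__init__()
        self.ops = nn.ModuleList(candidate_ops)

    def forward(self, x, alpha_edge):
        # alpha_edge: architecture logits alpha_o^{(i,j)}, one entry per operation
        weights = F.softmax(alpha_edge, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```

Discretization at the end of search then amounts to keeping `self.ops[alpha_edge.argmax()]` on each edge.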
Jointly learn the architecture $\alpha$ and the weights $w$ within all the mixed operations
Optimize the validation loss with respect to $\alpha$ using gradient descent
Find $\alpha^{*}$ that minimizes the validation loss $\mathcal{L}_{val}\left(w^{*}, \alpha^{*}\right)$, where the weights $w^{*}$ minimize the training loss: $w^{*}=\operatorname{argmin}_{w} \mathcal{L}_{train}\left(w, \alpha^{*}\right)$
Bilevel optimization problem:
$$
\begin{array}{cl}
\min _{\alpha} & \mathcal{L}_{val}\left(w^{*}(\alpha), \alpha\right) \\
\text { s.t. } & w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{train}(w, \alpha)
\end{array}
$$
Architecture $\alpha$ could be viewed as a special type of hyperparameter
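In practice the bilevel problem is handled by alternating single gradient steps on $w$ (training set) and $\alpha$ (validation set). A hedged sketch of that outer loop; `model`, the data loaders, the parameter-group helpers `weight_parameters()`/`arch_parameters()`, and `arch_gradient` (sketched in the next subsection) are all assumptions, and the learning rates are placeholders:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
w_optimizer = torch.optim.SGD(model.weight_parameters(), lr=0.025, momentum=0.9)
alpha_optimizer = torch.optim.Adam(model.arch_parameters(), lr=3e-4)

for (x_train, y_train), (x_val, y_val) in zip(train_loader, val_loader):
    # Step 1: update alpha by descending the approximate validation-loss gradient.
    grads = arch_gradient(model, criterion, x_train, y_train, x_val, y_val, xi=0.025)
    alpha_optimizer.zero_grad()
    for a, g in zip(model.arch_parameters(), grads):
        a.grad = g
    alpha_optimizer.step()

    # Step 2: update w by descending the training-loss gradient with alpha fixed.
    w_optimizer.zero_grad()
    criterion(model(x_train), y_train).backward()
    w_optimizer.step()
```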
Approximate architecture gradient
Summary:
Approximation technique to make the algorithm computationally feasible and efficient
Simple approximation scheme:
$$
\nabla_{\alpha} \mathcal{L}_{val}\left(w^{*}(\alpha), \alpha\right) \approx \nabla_{\alpha} \mathcal{L}_{val}\left(w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha), \alpha\right)
$$
$\xi$ is the learning rate for a step of the inner optimization; $w^{*}(\alpha)$ is approximated by adapting $w$ using only a single training step
Related techniques have been used in meta-learning for model transfer, gradient-based hyperparameter tuning, and unrolled GANs
Convergence: no theoretical guarantee, but in practice the procedure reaches a fixed point with a suitable choice of $\xi$
Apply chain rule to the approximate architecture gradient:
$$
\nabla_{\alpha} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)-\xi \nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)
$$
where: $w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{t r a i n}(w, \alpha)$, denotes the weights for a one-step forward model
The expression above contains an expensive matrix-vector product in its second term
Using the finite difference approximation:
Let $w^{\pm}=w \pm \epsilon \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right)$. Then:
$$
\nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{train}\left(w^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{train}\left(w^{-}, \alpha\right)}{2 \epsilon}
$$
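A hedged sketch combining the one-step unrolled weights $w'$ with the finite-difference Hessian-vector product above; `weight_parameters()`/`arch_parameters()` are assumed helpers, and `eps` is a small scalar (the paper scales it by the norm of $\nabla_{w'}\mathcal{L}_{val}$):

```python
import copy
import torch

def arch_gradient(model, criterion, x_train, y_train, x_val, y_val, xi, eps=1e-2):
    """Approximate grad_alpha L_val(w*(alpha), alpha) via a single unrolled step."""
    w = list(model.weight_parameters())
    alpha = list(model.arch_parameters())

    # w' = w - xi * grad_w L_train(w, alpha): the one-step forward model.
    g_train = torch.autograd.grad(criterion(model(x_train), y_train), w)
    unrolled = copy.deepcopy(model)
    for p, p0, g in zip(unrolled.weight_parameters(), w, g_train):
        p.data.copy_(p0.data - xi * g)

    # grad_alpha L_val(w', alpha) and v = grad_{w'} L_val(w', alpha).
    val_loss = criterion(unrolled(x_val), y_val)
    d_alpha = torch.autograd.grad(val_loss, list(unrolled.arch_parameters()),
                                  retain_graph=True)
    v = torch.autograd.grad(val_loss, list(unrolled.weight_parameters()))

    # Finite-difference Hessian-vector product:
    # (grad_alpha L_train(w+, alpha) - grad_alpha L_train(w-, alpha)) / (2 eps).
    grads_pm = []
    for sign in (+1.0, -1.0):
        with torch.no_grad():
            for p, vi in zip(w, v):
                p.add_(sign * eps * vi)          # move the weights to w+ (then w-)
        grads_pm.append(torch.autograd.grad(criterion(model(x_train), y_train), alpha))
        with torch.no_grad():
            for p, vi in zip(w, v):
                p.sub_(sign * eps * vi)          # restore the original weights
    hvp = [(gp - gm) / (2 * eps) for gp, gm in zip(*grads_pm)]

    # Total gradient: grad_alpha L_val(w', alpha) - xi * Hessian-vector product.
    return [da - xi * h for da, h in zip(d_alpha, hvp)]
```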
First order approximation:
When $\xi=0$, the architecture gradient reduces to $\nabla_{\alpha} \mathcal{L}_{val}(w, \alpha)$, corresponding to the simple heuristic of optimizing the validation loss by assuming the current $w$ is the same as $w^{*}(\alpha)$
This leads to a speed-up but empirically worse performance
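For comparison, the first-order variant needs no unrolling and no second-order term; a minimal sketch under the same assumptions as above:

```python
import torch

def arch_gradient_first_order(model, criterion, x_val, y_val):
    # With xi = 0 the architecture gradient is simply grad_alpha L_val(w, alpha),
    # evaluated at the current weights w.
    val_loss = criterion(model(x_val), y_val)
    return torch.autograd.grad(val_loss, list(model.arch_parameters()))
```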
Deriving discrete architectures
To form each node in the discrete architecture, retain the top-$k$ strongest operations (among all non-zero candidate operations) collected from all previous nodes ($k=2$ for convolutional cells, $k=1$ for recurrent cells)
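A rough sketch of this derivation step, assuming `alpha` is a dict mapping each edge $(i,j)$ to its logit vector and `op_names` lists the candidate operations including "zero" (an illustrative layout, not the official data structure):

```python
import torch

def derive_cell(alpha, op_names, num_inputs, num_intermediate, k=2):
    """Keep, for each intermediate node, the k strongest incoming edges and the
    best non-zero operation on each."""
    genotype = []
    zero_idx = op_names.index("zero")
    for j in range(num_inputs, num_inputs + num_intermediate):
        candidates = []
        for i in range(j):
            weights = torch.softmax(alpha[(i, j)].detach(), dim=-1)
            weights[zero_idx] = 0.0                 # exclude the zero operation
            best = int(torch.argmax(weights))
            candidates.append((float(weights[best]), i, best))
        # Edge strength = weight of its best non-zero op; keep the top-k edges.
        for _, i, best in sorted(candidates, reverse=True)[:k]:
            genotype.append((op_names[best], i, j))
    return genotype
```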
Experiments and results
The experiments consist of two stages:
- architecture search: search for the cell architectures using DARTS, and determine the best cells based on their validation performance
- architecture evaluation: use these cells to construct larger architectures, which we train from scratch and report their performance on the test set
Investigate the transferability of the best cells
Architecture search
on CIFAR-10
$\mathcal{O}$: 3×3 and 5×5 separable convolutions, 3×3 and 5×5 dilated separable convolutions, 3×3 max pooling, 3×3 average pooling, identity, and zero
All convolutional operations use the ReLU-Conv-BN order; each separable conv is always applied twice
The conv cell consists of N=7 nodes. The network is formed by stacking multiple cells together. Cells located at 1/3 and 2/3 of the total depth of the network are reduction cells
Architecture encoding therefore is $(\alpha_{normal}, \alpha_{reduce})$
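For concreteness, the eight candidate operations above might be organized as a name-to-module table; a sketch only (the real implementation uses full depthwise-separable and dilated-separable blocks, and stride handling for identity/zero in reduction cells is omitted here):

```python
import torch
import torch.nn as nn

def relu_conv_bn(c, kernel, stride, dilation=1, groups=1):
    """ReLU-Conv-BN unit: depthwise conv followed by a pointwise conv (sketch)."""
    padding = dilation * (kernel - 1) // 2
    return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(c, c, kernel, stride=stride, padding=padding,
                  dilation=dilation, groups=groups, bias=False),
        nn.Conv2d(c, c, 1, bias=False),
        nn.BatchNorm2d(c),
    )

def conv_candidates(c, stride=1):
    """Candidate operations O for one conv-cell edge; separable convs applied twice."""
    sep = lambda k: nn.Sequential(relu_conv_bn(c, k, stride, groups=c),
                                  relu_conv_bn(c, k, 1, groups=c))
    return {
        "sep_conv_3x3": sep(3),
        "sep_conv_5x5": sep(5),
        "dil_conv_3x3": relu_conv_bn(c, 3, stride, dilation=2, groups=c),
        "dil_conv_5x5": relu_conv_bn(c, 5, stride, dilation=2, groups=c),
        "max_pool_3x3": nn.MaxPool2d(3, stride=stride, padding=1),
        "avg_pool_3x3": nn.AvgPool2d(3, stride=stride, padding=1),
        "identity":     nn.Identity(),
        "zero":         lambda x: torch.zeros_like(x),
    }
```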
on PENN TREEBANK
$\mathcal{O}$: linear transformations followed by tanh, relu, or sigmoid activations, plus identity and zero
recurrent cell consists of N=12 nodes
Enable batch normalization in each node to prevent gradient explosion during architecture search, and disable it during architecture evaluation
Recurrent network consists of only a single cell
Architecture evaluation
Run DARTS four times with different random seeds and pick the best cell based on its validation performance
To evaluate the selected architecture, we randomly initialize its weights (weights learned during the search process are discarded), train it from scratch, and report its performance on the test set
Results analysis
DARTS achieves results comparable with the state of the art while using three orders of magnitude less computation
Search time is slightly longer than the most efficient weight-sharing baselines (e.g., ENAS)
Alternative Optimization Strategies
Two alternative optimization strategies were tried, but both gave worse performance:
- $\alpha$ and $w$ are jointly optimized over the union of training and validation sets using coordinate descent
- optimize $\alpha$ simultaneously with $w$ (without alternation) using SGD, again over all the available data (training + validation)
In DARTS, $\alpha$ is not directly optimized on the training data, which helps prevent the architecture from overfitting it
Conclusion
There are many interesting directions to improve DARTS further.
For example, the current method may suffer from discrepancies between the continuous architecture encoding and the derived discrete architecture. This could be alleviated, e.g., by annealing the softmax temperature (with a suitable schedule) to enforce one-hot selection.
It would also be interesting to investigate performance-aware architecture derivation schemes based on the shared parameters learned during the search process.