DARTS: Differentiable Architecture Search

ICLR 2019

Abstract

Formulates architecture search in a differentiable manner

Based on a continuous relaxation of the architecture representation, allowing efficient search over architectures using gradient descent

Orders of magnitude faster than state-of-the-art non-differentiable techniques

Implementation is publicly available

Introduction

Existing architecture search algorithms are computationally demanding:

  • 2000 GPU days of RL
  • 3150 GPU days of evolution

Several approaches for speeding up have been proposed:

  • imposing a particular structure of the search space
  • weights or performance prediction for each individual architecture
  • weight sharing/inheritance across multiple architectures

But the fundamental challenge of scalability remains: architecture search is treated as a black-box optimization problem over a discrete domain, which requires a large number of architecture evaluations

Instead of searching over a discrete set of candidate architectures, we relax the search space to be continuous, so the architecture can be optimized with gradient descent

Yet it is generic enough to handle both convolutional and recurrent architectures

While prior works seek to fine-tune a specific aspect of an architecture, such as filter shapes or branching patterns in a convolutional network, DARTS is able to learn high-performance architecture building blocks with complex graph topologies within a rich search space

Contributions

  • a novel algorithm for differentiable network architecture search based on bilevel optimization, applicable to both convolutional and recurrent architectures
  • highly competitive results on CIFAR-10 and PTB
  • remarkable efficiency improvement
  • architectures learned by DARTS on CIFAR-10 and PTB are transferable to ImageNet and WikiText-2

Search space

computation cell

The learned cell could either be stacked to form a convolutional network or recursively connected to form a recurrent network

A cell is a directed acyclic graph consisting of an ordered sequence of N nodes. Each node $x^{(i)}$ is a latent representation. Each directed edge $(i,j)$ is associated with an operation $o^{(i,j)}$ that transforms $x^{(i)}$

Each cell has two input nodes and a single output node. Each intermediate node is computed from all of its predecessors:
$$
x^{(j)}=\sum_{i<j}o^{(i,j)}(x^{(i)})
$$
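A minimal sketch of this DAG computation, assuming the operation on each edge has already been chosen. The names (`cell_forward`, `ops`, `num_nodes`) are illustrative, not the paper's code, and the last node is returned as the output for simplicity:

```python
# Sketch only: compute node outputs of a DARTS-style cell.
# `ops[(i, j)]` is an assumed mapping from edge (i, j) to a callable operation.
def cell_forward(inputs, ops, num_nodes):
    """inputs: list of the two input-node tensors; ops: dict edge -> callable."""
    states = list(inputs)                      # nodes computed so far
    for j in range(len(inputs), num_nodes):
        # x^(j) = sum over all predecessors i < j of o^(i,j)(x^(i))
        states.append(sum(ops[(i, j)](states[i]) for i in range(j)))
    return states[-1]                          # return the last node as the output
```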
DARTS

The task of learning the cell therefore reduces to learning the operations on its edges

Continuous Relaxation and Optimization

Summary:

A continuous relaxation scheme for the search space, which leads to a differentiable learning objective for the joint optimization of the architecture and its weights

Let $\mathcal{O}$ be a set of candidate operations (e.g., convolution, max pooling, zero)

Relax the categorical choice of a particular operation to a softmax over all possible operations
$$
\overline{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x)
$$
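A possible PyTorch sketch of this mixed operation (not the authors' code); `MixedOp` and `candidate_ops` are illustrative names, and each candidate op is assumed to preserve the tensor shape:

```python
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """Continuous relaxation of one edge: a softmax-weighted sum over all
    candidate operations. Sketch only."""

    def __init__(self, candidate_ops):
        super().__init__()
        self.ops = nn.ModuleList(candidate_ops)

    def forward(self, x, alpha_edge):
        # alpha_edge: architecture logits alpha^(i,j), one entry per candidate op
        weights = F.softmax(alpha_edge, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```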

The vector $\alpha^{(i,j)}$ contains the operation mixing weights for the pair of nodes $(i,j)$

Architecture search then reduces to learning a set of continuous variables $\alpha = \{\alpha^{(i,j)}\}$. At the end of search, each mixed operation is replaced by the most likely operation:
$$
o^{(i, j)}=\operatorname{argmax}_{o \in \mathcal{O}} \alpha_{o}^{(i, j)}
$$
Jointly learn the architecture $\alpha$ and the weights $w$ within all the mixed operations

Optimize the validation loss with respect to the architecture, using gradient descent

Find $\alpha^{*}$ that minimizes the validation loss $\mathcal{L}_{val}\left(w^{*}, \alpha^{*}\right)$, where the weights $w^{*}$ minimize the training loss: $w^{*}=\operatorname{argmin}_{w} \mathcal{L}_{train}\left(w, \alpha^{*}\right)$

Bilevel optimization problem:
$$
\begin{array}{cl}
\min_{\alpha} & \mathcal{L}_{val}\left(w^{*}(\alpha), \alpha\right) \\
\text{s.t.} & w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{train}(w, \alpha)
\end{array}
$$
Architecture $\alpha$ could be viewed as a special type of hyperparameter

Approximate architecture gradient

Summary:

Approximation technique to make the algorithm computationally feasible and efficient

Simple approximation scheme:
$$
\nabla_{\alpha} \mathcal{L}_{val}\left(w^{*}(\alpha), \alpha\right) \approx \nabla_{\alpha} \mathcal{L}_{val}\left(w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha), \alpha\right)
$$

$\xi$ is the learning rate for a step of inner optimization, approximate $w^{*}(\alpha)$ by adapting $w$ using only a single training step

Related techniques have been used in meta-learning for model transfer, gradient based hyperparameter tuning, unrolled GAN

[Figure: DARTS algorithm — alternating optimization of the architecture $\alpha$ and the weights $w$]
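A hedged sketch of one alternating update, assuming the approximate architecture gradient above is handled by a separate routine; all names (`search_step`, `arch_grad_step`, `w_optimizer`, `criterion`) are illustrative, not the paper's API:

```python
def search_step(model, w_optimizer, criterion,
                x_train, y_train, x_val, y_val,
                arch_grad_step, xi):
    """One alternating DARTS-style update (illustrative sketch)."""
    # 1. Update architecture alpha by descending the validation loss,
    #    using w' = w - xi * grad_w L_train(w, alpha) as a proxy for w*(alpha).
    arch_grad_step(model, x_train, y_train, x_val, y_val, xi)

    # 2. Update weights w by descending the training loss at the current alpha.
    w_optimizer.zero_grad()
    loss = criterion(model(x_train), y_train)
    loss.backward()
    w_optimizer.step()
    return loss.item()
```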

Convergence: no formal guarantee is given, but in practice the procedure is able to reach a fixed point with a suitable choice of $\xi$

Apply chain rule to the approximate architecture gradient:
$$
\nabla_{\alpha} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)-\xi \nabla_{\alpha, w}^{2} \mathcal{L}_{t r a i n}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{v a l}\left(w^{\prime}, \alpha\right)
$$
where $w^{\prime}=w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)$ denotes the weights for a one-step forward model

The second term of the expression above contains an expensive matrix-vector product

Using the finite difference approximation:
Let $w^{\pm}=w \pm \epsilon \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right)$. Then:
$$
\nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha) \nabla_{w^{\prime}} \mathcal{L}_{val}\left(w^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{train}\left(w^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{train}\left(w^{-}, \alpha\right)}{2 \epsilon}
$$
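A possible PyTorch implementation of this finite-difference trick, as a sketch: it assumes the weights can be perturbed in place and that `loss_train_fn` recomputes the training loss under the current weights. The paper chooses $\epsilon$ by scaling a small constant with the gradient norm; a fixed default is used here for simplicity.

```python
import torch

def hessian_vector_product(loss_train_fn, w_params, alpha_params,
                           dval_dwprime, eps=1e-2):
    """Finite-difference approximation of
    grad^2_{alpha,w} L_train(w, alpha) . grad_{w'} L_val(w', alpha)."""
    # w+ = w + eps * v, where v = grad_{w'} L_val(w', alpha)
    with torch.no_grad():
        for w, v in zip(w_params, dval_dwprime):
            w.add_(eps * v)
    grads_pos = torch.autograd.grad(loss_train_fn(), alpha_params)

    # w- = w - eps * v
    with torch.no_grad():
        for w, v in zip(w_params, dval_dwprime):
            w.sub_(2 * eps * v)
    grads_neg = torch.autograd.grad(loss_train_fn(), alpha_params)

    # restore the original weights
    with torch.no_grad():
        for w, v in zip(w_params, dval_dwprime):
            w.add_(eps * v)

    return [(gp - gn) / (2 * eps) for gp, gn in zip(grads_pos, grads_neg)]
```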

First order approximation:

When $\xi=0$, the architecture gradient is given by $\nabla_{\alpha} \mathcal{L}_{val}(w, \alpha)$, corresponding to the simple heuristic of optimizing the validation loss by assuming the current $w$ is the same as $w^{*}(\alpha)$

This leads to a speed-up but empirically worse performance
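Under this first-order approximation, the architecture step reduces to a plain gradient step on the validation loss at the current weights. A sketch with illustrative names (`alpha_optimizer` is assumed to track only the architecture parameters):

```python
def first_order_arch_step(model, alpha_optimizer, criterion, x_val, y_val):
    # xi = 0: treat the current w as a stand-in for w*(alpha) and simply
    # descend the validation loss with respect to alpha.
    alpha_optimizer.zero_grad()
    loss = criterion(model(x_val), y_val)
    loss.backward()          # computes grads for all parameters, but only the
    alpha_optimizer.step()   # alpha parameters in alpha_optimizer are updated
```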

Deriving discrete architectures

To form each node in the discrete architecture, we retain the top-k strongest operations among all non-zero candidate operations collected from all the previous nodes
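As an illustration, a small sketch of this top-k selection for a single node (a hypothetical helper, not the authors' code; in the paper k=2 for convolutional cells and k=1 for recurrent cells):

```python
import torch

def derive_node_inputs(alpha_rows, op_names, k=2):
    """`alpha_rows`: one alpha vector per incoming edge from a previous node;
    `op_names`: candidate operation names, with 'zero' excluded from selection.
    Returns the k strongest (weight, edge index, op name) choices."""
    edge_choices = []
    for edge_idx, alpha in enumerate(alpha_rows):
        weights = torch.softmax(alpha, dim=-1)
        # strength of an edge = largest weight among its non-'zero' operations
        best_w, best_op = max(
            (w.item(), name) for w, name in zip(weights, op_names) if name != 'zero'
        )
        edge_choices.append((best_w, edge_idx, best_op))
    # keep the top-k strongest incoming edges with their chosen operations
    return sorted(edge_choices, reverse=True)[:k]
```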

Experiments and results

Consist of two stages:

  • architecture search: search for the cell architectures using DARTS, and determine the best cells based on their validation performance
  • architecture evaluation: use these cells to construct larger architectures, which we train from scratch and report their performance on the test set

Investigate the transferability of the best cells

on CIFAR-10

$\mathcal{O}$: 3×3 and 5×5 separable conv, 3×3 and 5×5 dilated separable conv, 3×3 max pooling, 3×3 average pooling, identity, zero

Conv operations use the ReLU-Conv-BN order; each separable conv is always applied twice

The conv cell consists of N=7 nodes. The network is then formed by stacking multiple cells together. Cells located at 1/3 and 2/3 of the total depth of the network are reduction cells

Architecture encoding therefore is $(\alpha_{normal}, \alpha_{reduce})$
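For illustration, a minimal sketch of how this encoding could be initialized, assuming a cell with 2 input nodes and 4 intermediate nodes (hence 14 edges) and 8 candidate operations; the function and parameter names are hypothetical:

```python
import torch
import torch.nn as nn

def init_alphas(num_intermediate_nodes=4, num_ops=8):
    """One row of logits per edge of the cell DAG, one column per candidate op.
    Small random init so the softmax starts near uniform."""
    # intermediate node i has (2 + i) incoming edges (2 inputs + earlier nodes)
    num_edges = sum(2 + i for i in range(num_intermediate_nodes))
    alpha_normal = nn.Parameter(1e-3 * torch.randn(num_edges, num_ops))
    alpha_reduce = nn.Parameter(1e-3 * torch.randn(num_edges, num_ops))
    return alpha_normal, alpha_reduce
```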

on PENN TREEBANK

$\mathcal{O}$: linear transformations followed by tanh, relu, or sigmoid activations

recurrent cell consists of N=12 nodes

Enable batch normalization in each node to prevent gradient explosion during architecture search, and disable it during architecture evaluation

Recurrent network consists of only a single cell

Architecture evaluation

Run DARTS four times with different random seeds and pick the best cell based on its validation performance

To evaluate the selected architecture, we randomly initialize its weights (weights learned during the search process are discarded), train it from scratch, and report its performance on the test set

Results analysis

DARTS achieves results comparable with the state of the art while using three orders of magnitude fewer computation resources

Search time is slightly longer than ENAS

Alternative Optimization Strategies

Two alternatives were tried, but both gave worse performance:

  • $\alpha$ and $w$ are jointly optimized over the union of training and validation sets using coordinate descent
  • optimize $\alpha$ simultaneously with $w$ (without alternation) using SGD, again over all the data available (training + validation)

In DARTS, $\alpha$ is not directly optimized on the training set

Conclusion

There are many interesting directions to improve DARTS further.

For example, the current method may suffer from discrepancies between the continuous architecture encoding and the derived discrete architecture. This could be alleviated, e.g., by annealing the softmax temperature (with a suitable schedule) to enforce one-hot selection.

It would also be interesting to investigate performance-aware architecture derivation schemes based on the shared parameters learned during the search process.