Improving performance of deep learning models with axiomatic attribution priors and expected gradients

Abstract

Recent research has demonstrated that feature attribution methods for deep networks can themselves be incorporated into training; these attribution priors optimize for a model whose attributions have certain desirable properties, most frequently that particular features are important or unimportant. These attribution priors are often based on attribution methods that are not guaranteed to satisfy desirable interpretability axioms, such as completeness and implementation invariance. Here we introduce attribution priors that optimize for higher-level properties of explanations, such as smoothness and sparsity, enabled by expected gradients (EG), a fast new formulation of a feature attribution method that satisfies many important interpretability axioms. This improves model performance on many real-world tasks where previous attribution priors fail. Our experiments show that the gains from combining higher-level attribution priors with expected gradients attributions are consistent across image, gene expression and healthcare datasets. We believe that this work motivates and provides the necessary tools to support the widespread adoption of axiomatic attribution priors in many areas of applied machine learning. Our implementations and results are freely available to the academic community.
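The abstract does not spell out how expected gradients is computed. The sketch below assumes the formulation in which EG is integrated gradients averaged over references x′ drawn from a background dataset, φ_i(x) = E_{x′∼D, α∼U(0,1)}[(x_i − x′_i) ∂f(x′ + α(x − x′))/∂x_i], estimated by Monte Carlo with one reference and one α per draw. The function name `expected_gradients` and the `f_grad` callable are hypothetical; a real training pipeline would take gradients from TensorFlow or PyTorch autodiff rather than from a hand-written gradient function.

```python
import numpy as np


def expected_gradients(f_grad, x, background, n_samples=200, rng=None):
    """Monte Carlo estimate of expected gradients for a single input.

    f_grad: hypothetical callable returning d f / d input at a point, shape (d,)
    x: input to explain, shape (d,)
    background: reference samples assumed drawn from the data, shape (m, d)
    """
    rng = np.random.default_rng(rng)
    # Draw one reference x' and one interpolation coefficient alpha per sample.
    refs = background[rng.integers(0, len(background), size=n_samples)]
    alphas = rng.uniform(0.0, 1.0, size=(n_samples, 1))
    # Points on the straight-line path from each reference to x.
    points = refs + alphas * (x - refs)
    grads = np.stack([f_grad(p) for p in points])
    # Average (x - x') * grad f(x' + alpha (x - x')) over all draws.
    return ((x - refs) * grads).mean(axis=0)


# Usage: for a linear model f(x) = w . x with an all-zero reference,
# the gradient is constant, so the attributions equal w * x exactly.
w = np.array([1.0, -2.0, 0.5])
x = np.array([1.0, 1.0, 2.0])
phi = expected_gradients(lambda p: w, x, np.zeros((1, 3)), n_samples=50, rng=0)
print(phi)
```

Because each draw needs only one reference and one α, the estimate is cheap enough to recompute at every training step. Averaged over many draws, the attributions sum approximately to f(x) minus the mean model output over the background data, which is the completeness property named in the abstract.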

A preprint version of the article is available at ArXiv.

Fig. 1: Expected gradients (EG) is a feature attribution method designed to be regularized during training.
Fig. 2: Pixel attribution prior improves saliency map smoothness and increases robustness of MNIST classifier to noise.
Fig. 3: Pixel attribution prior improves saliency map smoothness and increases robustness of CIFAR-10 classifier to noise.
Fig. 4: Graph attribution prior improves test accuracy and biological relevance of anticancer drug response prediction model.
Fig. 5: Sparse attribution prior builds sparser and more accurate healthcare mortality models.
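The figure captions name three attribution priors: a pixel prior for image smoothness, a graph prior for the drug response model and a sparse prior for the mortality models. The NumPy sketch below evaluates one plausible penalty for each on a precomputed attribution array: total variation between neighbouring pixel attributions, a graph-Laplacian penalty φ̄ᵀLφ̄ on mean absolute feature attributions, and the negative Gini coefficient of those mean attributions. These forms are assumptions, the paper's exact definitions (normalization, choice of graph, weighting) may differ, and all function names are hypothetical. In training, each penalty would be written in an autodiff framework, applied to EG attributions and added to the task loss with a tunable weight λ.

```python
import numpy as np


def pixel_smoothness_penalty(attr):
    """Total-variation-style penalty (assumed form).

    attr: attributions of shape (batch, height, width). Sums absolute
    differences between vertically and horizontally adjacent pixels.
    """
    dh = np.abs(np.diff(attr, axis=1)).sum()
    dw = np.abs(np.diff(attr, axis=2)).sum()
    return dh + dw


def graph_penalty(attr, adjacency):
    """Graph-Laplacian penalty phi_bar^T L phi_bar (assumed form).

    attr: (batch, n_features); adjacency: symmetric (n_features, n_features).
    Equals the sum over edges of squared differences in mean |attribution|.
    """
    phi_bar = np.abs(attr).mean(axis=0)
    laplacian = np.diag(adjacency.sum(axis=1)) - adjacency
    return phi_bar @ laplacian @ phi_bar


def gini(values, eps=1e-12):
    """Gini coefficient of non-negative values: 0 if uniform, near 1 if sparse."""
    v = np.sort(values)
    n = len(v)
    idx = np.arange(1, n + 1)
    # eps guards against division by zero when all values are zero.
    return ((2 * idx - n - 1) @ v) / (n * (v.sum() + eps))


def sparsity_penalty(attr):
    """Negative Gini of mean |attribution|; minimizing it favours sparsity."""
    return -gini(np.abs(attr).mean(axis=0))
```

Total variation penalizes abrupt changes between neighbours without forcing attributions toward zero, whereas the Gini-based penalty is scale-invariant, so it rewards concentrating attribution on few features rather than shrinking all of them.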

Data availability

The data for all experiments and figures in the paper are publicly available. A downloadable version of the dataset used for the sparsity experiment, as well as links to download the datasets used in the image and graph prior experiments, is available at https://github.com/suinleelab/attributionpriors. Data for the benchmarks were published as part of ref. 57 and can be accessed at https://github.com/suinleelab/treeexplainer-study/tree/master/benchmark.

Code availability

Implementations of attribution priors for TensorFlow and PyTorch are available at https://github.com/suinleelab/attributionpriors. This repository also contains code reproducing the main results of the paper. The specific version of the code used in this paper is archived at ref. 58.

References

  1. Lundberg, S. M. & Lee, S.-I. A unified approach to interpreting model predictions. In Advances in Neural Information Processing Systems Vol. 30, 4765–4774 (NeurIPS, 2017).

  2. Sundararajan, M., Taly, A. & Yan, Q. Axiomatic attribution for deep networks. In Proc. 34th International Conference on Machine Learning Vol. 70, 3319–3328 (Journal of Machine Learning Research, 2017).

  3. Štrumbelj, E. & Kononenko, I. Explaining prediction models and individual predictions with feature contributions. Knowl. Inf. Syst. 41, 647–665 (2014).

  4. Datta, A., Sen, S. & Zick, Y. Algorithmic transparency via quantitative input influence: theory and experiments with learning systems. In 2016 IEEE Symposium on Security and Privacy (SP) 598–617 (IEEE, 2016).

  5. Lundberg, S. M. et al. From local explanations to global understanding with explainable AI for trees. Nat. Mach. Intell. 2, 56–67 (2020).

  6. Lundberg, S. M. et al. Explainable machine-learning predictions for the prevention of hypoxaemia during surgery. Nat. Biomed. Eng. 2, 749–760 (2018).

  7. Sayres, R. et al. Using a deep learning algorithm and integrated gradients explanation to assist grading for diabetic retinopathy. Ophthalmology 126, 552–564 (2019).

  8. Zech, J. R. et al. Variable generalization performance of a deep learning model to detect pneumonia in chest radiographs: a cross-sectional study. PLoS Med. 15, e1002683 (2018).

  9. Ross, A. S., Hughes, M. C. & Doshi-Velez, F. Right for the right reasons: training differentiable models by constraining their explanations. In Proc. 26th International Joint Conference on Artificial Intelligence 2662–2670 (IJCAI, 2017).

  10. Schramowski, P. et al. Making deep neural networks right for the right scientific reasons by interacting with their explanations. Nat. Mach. Intell. 2, 476–486 (2020).

  11. Ilyas, A. et al. Adversarial examples are not bugs, they are features. In Advances in Neural Information Processing Systems Vol. 32 (NeurIPS, 2019).

  12. Liu, F. & Avci, B. Incorporating priors with feature attribution on text classification. In Proc. 57th Annual Meeting of the Association for Computational Linguistics 6274–6283 (ACL, 2019).

  13. Chen, J., Wu, X., Rastogi, V., Liang, Y. & Jha, S. Robust attribution regularization. In Advances in Neural Information Processing Systems Vol. 32 (NeurIPS, 2019).

  14. Rieger, L., Singh, C., Murdoch, W. J. & Yu, B. Interpretations are useful: penalizing explanations to align neural networks with prior knowledge. In Proc. 37th International Conference on Machine Learning (eds. Daumé III, H. & Singh, A.) 8116–8126 (ICML, 2020).

  15. LeCun, Y., Cortes, C. & Burges, C. MNIST Handwritten Digit Database (AT&T Labs) http://yann.lecun.com/exdb/mnist (2010).

  16. Yu, F., Xu, Z., Wang, Y., Liu, C. & Chen, X. Towards robust training of neural networks by regularizing adversarial gradients. Preprint at https://arxiv.org/abs/1805.09370 (2018).

  17. Jakubovitz, D. & Giryes, R. Improving DNN robustness to adversarial attacks using Jacobian regularization. In Proc. European Conference on Computer Vision (ECCV) (eds. Ferrari, V., Hebert, M., Sminchisescu, C. & Weiss, Y.) 514–529 (ECCV, 2018).

  18. Roth, K., Lucchi, A., Nowozin, S. & Hofmann, T. Adversarially robust training through structured gradient regularization. Preprint at https://arxiv.org/abs/1805.08736 (2018).

  19. Selvaraju, R. R. et al. Grad-CAM: visual explanations from deep networks via gradient-based localization. In Proc. IEEE International Conference on Computer Vision 618–626 (IEEE, 2017).

  20. Ross, A. S. & Doshi-Velez, F. Improving the adversarial robustness and interpretability of deep neural networks by regularizing their input gradients. In Thirty-second AAAI Conference on Artificial Intelligence Vol. 32, No. 1 (AAAI, 2018).

  21. Smilkov, D., Thorat, N., Kim, B., Viégas, F. & Wattenberg, M. SmoothGrad: removing noise by adding noise. Preprint at https://arxiv.org/abs/1706.03825 (2017).

  22. Fong, R. C. & Vedaldi, A. Interpretable explanations of black boxes by meaningful perturbation. In Proc. IEEE International Conference on Computer Vision 3429–3437 (IEEE, 2017).

  23. Krizhevsky, A. et al. Learning Multiple Layers of Features from Tiny Images Technical Report (Citeseer, 2009).

  24. Simonyan, K. & Zisserman, A. Very deep convolutional networks for large-scale image recognition. In 3rd International Conference on Learning Representations (eds. Bengio, Y. & LeCun, Y.) (ICLR, 2015).

  25. Recht, B., Roelofs, R., Schmidt, L. & Shankar, V. Do ImageNet classifiers generalize to ImageNet? In Proc. 36th International Conference on Machine Learning Vol. 97, 5389–5400 (PMLR, 2019).

  26. Tsipras, D., Santurkar, S., Engstrom, L., Turner, A. & Madry, A. Robustness may be at odds with accuracy. In 7th International Conference on Learning Representations (ICLR, 2019).

  27. Zhang, H. et al. Theoretically principled trade-off between robustness and accuracy. In Proc. 36th International Conference on Machine Learning Vol. 97, 7472–7482 (PMLR, 2019).

  28. Cheng, W., Zhang, X., Guo, Z., Shi, Y. & Wang, W. Graph-regularized dual Lasso for robust eQTL mapping. Bioinformatics 30, i139–i148 (2014).

  29. Tyner, J. W. et al. Functional genomic landscape of acute myeloid leukaemia. Nature 562, 526–531 (2018).

  30. Greene, C. S. et al. Understanding multicellular function and disease with human tissue-specific networks. Nat. Genet. 47, 569–576 (2015).

  31. Kipf, T. N. & Welling, M. Semi-supervised classification with graph convolutional networks. In 5th International Conference on Learning Representations (ICLR, 2017).

  32. Subramanian, A. et al. Gene set enrichment analysis: a knowledge-based approach for interpreting genome-wide expression profiles. Proc. Natl Acad. Sci. USA 102, 15545–15550 (2005).

  33. Benjamini, Y. & Hochberg, Y. Controlling the false discovery rate: a practical and powerful approach to multiple testing. J. R. Stat. Soc. B 57, 289–300 (1995).

  34. Liu, J. et al. Meis1 is critical to the maintenance of human acute myeloid leukemia cells independent of MLL rearrangements. Ann. Hematol. 96, 567–574 (2017).

  35. Valk, P. J. M. et al. Prognostically useful gene-expression profiles in acute myeloid leukemia. N. Engl. J. Med. 350, 1617–1628 (2004).

  36. Feng, J. & Simon, N. Sparse-input neural networks for high-dimensional nonparametric regression and classification. Preprint at https://arxiv.org/abs/1711.07592 (2017).

  37. Scardapane, S., Comminiello, D., Hussain, A. & Uncini, A. Group sparse regularization for deep neural networks. Neurocomputing 241, 81–89 (2017).

  38. Ross, A., Lage, I. & Doshi-Velez, F. The neural lasso: local linear sparsity for interpretable explanations. In Workshop on Transparent and Interpretable Machine Learning in Safety Critical Environments, 31st Conference on Neural Information Processing Systems (2017).

  39. Shrikumar, A., Greenside, P. & Kundaje, A. Learning important features through propagating activation differences. In Proc. 34th International Conference on Machine Learning Vol. 70, 3145–3153 (Journal of Machine Learning Research, 2017).

  40. Hurley, N. & Rickard, S. Comparing measures of sparsity. IEEE Trans. Inf. Theory 55, 4723–4741 (2009).

  41. Zonoobi, D., Kassim, A. A. & Venkatesh, Y. V. Gini index as sparsity measure for signal reconstruction from compressive samples. IEEE J. Sel. Top. Signal Process. 5, 927–932 (2011).

  42. Miller, H. W. Plan and Operation of the Health and Nutrition Examination Survey, United States, 1971–1973 DHEW publication no. 79-55071 (PHS) (Department of Health, Education, and Welfare, 1973).

  43. Binder, A., Montavon, G., Lapuschkin, S., Müller, K.-R. & Samek, W. Layer-wise relevance propagation for neural networks with local renormalization layers. In International Conference on Artificial Neural Networks (eds. Villa, A. E. P., Masulli, P. & Rivero, A. J. P.) 63–71 (Springer, 2016).

  44. Friedman, E. J. Paths and consistency in additive cost sharing. Int. J. Game Theory 32, 501–518 (2004).

  45. Zhang, H., Cisse, M., Dauphin, Y. N. & Lopez-Paz, D. mixup: beyond empirical risk minimization. In 6th International Conference on Learning Representations (ICLR, 2018).

  46. Bardsley, J. M. Laplace-distributed increments, the Laplace prior, and edge-preserving regularization. J. Inverse Ill Posed Probl. 20, 271–285 (2012).

  47. Abadi, M. et al. TensorFlow: a system for large-scale machine learning. In 12th USENIX Symposium on Operating Systems Design and Implementation (OSDI ’16) 265–283 (2016).

  48. Lou, Y., Zeng, T., Osher, S. & Xin, J. A weighted difference of anisotropic and isotropic total variation model for image processing. SIAM J. Imaging Sci. 8, 1798–1823 (2015).

  49. Shi, Y. & Chang, Q. Efficient algorithm for isotropic and anisotropic total variation deblurring and denoising. J. Appl. Math. 2013, 797239 (2013).

  50. Liu, S. & Deng, W. Very deep convolutional neural network based image classification using small training sample size. In 2015 3rd IAPR Asian Conference on Pattern Recognition (ACPR) 730–734 (IEEE, 2015).

  51. Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I. & Salakhutdinov, R. Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15, 1929–1958 (2014).

  52. Kingma, D. P. & Ba, J. Adam: a method for stochastic optimization. In 3rd International Conference on Learning Representations (eds. Bengio, Y. & LeCun, Y.) (ICLR, 2015).

  53. Virtanen, P. et al. SciPy 1.0: fundamental algorithms for scientific computing in Python. Nat. Methods 17, 261–272 (2020).

  54. Preuer, K. et al. DeepSynergy: predicting anti-cancer drug synergy with deep learning. Bioinformatics 34, 1538–1546 (2018).

  55. Tibshirani, R. Regression shrinkage and selection via the Lasso. J. R. Stat. Soc. B 58, 267–288 (1996).

  56. Pedregosa, F. et al. Scikit-learn: machine learning in Python. J. Mach. Learn. Res. 12, 2825–2830 (2011).

  57. Lundberg, S. M. et al. Explainable AI for trees: from local explanations to global understanding. Preprint at https://arxiv.org/abs/1905.04610 (2019).

  58. Sturmfels, P., Erion, G. & Janizek, J. D. suinleelab/attributionpriors: Nature Machine Intelligence code. Zenodo https://doi.org/10.5281/zenodo.4608599 (2021).

Acknowledgements

The results published here are partially based on data generated by the Cancer Target Discovery and Development (CTD2) Network (https://ocg.cancer.gov/programs/ctd2/data-portal) established by the National Cancer Institute’s Office of Cancer Genomics. The authors received funding from the National Science Foundation (DBI-1759487 (S.-I.L.), DBI-1552309 (J.D.J., G.E., S.-I.L.), DGE-1256082 (S.M.L.)); American Cancer Society (RSG-14-257-01-TBG (J.D.J., P.S., S.-I.L.)); and National Institutes of Health (R01AG061132 (J.D.J., P.S., S.-I.L.), R35GM128638 (G.E., S.-I.L.), F30HL151074-01 (G.E., S.-I.L.), 5T32GM007266-46 (J.D.J., G.E.)).

Author information

Contributions

G.E., J.D.J., P.S. and S.M.L. conceived the study. G.E., J.D.J. and P.S. designed algorithms and experiments. P.S. and J.D.J. implemented core libraries for the research. G.E., J.D.J. and P.S. wrote code for and ran the experiments, plotted figures and contributed to the writing. S.M.L. contributed to the writing. S.-I.L. supervised research and method development, and contributed to the writing.

Corresponding author

Correspondence to Su-In Lee.

Ethics declarations

Competing interests

The authors declare no competing interests.

Additional information

Peer review information Nature Machine Intelligence thanks Ronny Luss, Andrew Ross and the other, anonymous, reviewer(s) for their contribution to the peer review of this work.

Publisher’s note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

Supplementary information

Supplementary Information

Supplementary Sections A–J and Figs. 1–20.

About this article

Cite this article

Erion, G., Janizek, J.D., Sturmfels, P. et al. Improving performance of deep learning models with axiomatic attribution priors and expected gradients. Nat Mach Intell (2021). https://doi.org/10.1038/s42256-021-00343-w
