Augmenting interpretable models with large language models during training

Recent large language models (LLMs), such as ChatGPT, have demonstrated remarkable prediction performance for a growing array of tasks. However, their proliferation into high-stakes domains and compute-limited settings has created a burgeoning need for interpretability and efficiency. We address this need by proposing Aug-imodels, a framework for leveraging the knowledge learned by LLMs to build extremely efficient and interpretable prediction models. Aug-imodels use LLMs during fitting but not during inference, allowing complete transparency and often a speed/memory improvement of greater than 1000x for inference compared to LLMs. We explore two instantiations of Aug-imodels in natural-language processing: Aug-GAM, which augments a generalized additive model with decoupled embeddings from an LLM, and Aug-Tree, which augments a decision tree with LLM feature expansions. Across a variety of text-classification datasets, both outperform their non-augmented, interpretable counterparts. Aug-GAM can even outperform much larger models, e.g. a 6-billion parameter GPT-J model, despite having 10,000x fewer parameters and being fully transparent. We further explore Aug-imodels in a natural-language fMRI study, where they generate interesting interpretations from scientific data.


Introduction
Large language models (LLMs) have demonstrated remarkable predictive performance across a growing range of diverse tasks [1][2][3]. However, their proliferation has led to two burgeoning problems. First, like most deep neural nets, LLMs have become increasingly difficult to interpret, often leading to them being characterized as black boxes and hindering their use in high-stakes applications such as science [4], medicine [5], and policy-making [6]. Moreover, the use of black-box models such as LLMs has come under increasing scrutiny in settings where users require explanations or where models struggle with issues such as fairness [7] and regulatory pressure [8]. Second, black-box LLMs have grown to massive sizes, incurring enormous energy costs [9] and making them costly and difficult to deploy, particularly in low-compute settings (e.g. edge devices).
As an alternative to large black-box models, transparent models, such as generalized additive models [10] and decision trees [11], can maintain complete interpretability. Additionally, transparent models tend to be dramatically more computationally efficient than LLMs. While transparent models can sometimes perform as well as black-box LLMs [12][13][14][15], in many settings (such as natural-language processing (NLP)), there is often a considerable gap between the performance of transparent models and black-box LLMs.
We address this gap by proposing Augmented Interpretable Models (Aug-imodels), a framework to leverage the knowledge learned by LLMs to build extremely efficient and interpretable models.

Fig. 1: Aug-imodels use an LLM to augment an interpretable model during fitting but not inference (toy example for movie-review classification). (A) Aug-GAM fits an augmented additive model by extracting fixed-size embeddings for decoupled ngrams in a given sequence, summing them, and using them to train a supervised linear model. (B) At test time, Aug-GAM can be interpreted exactly as a generalized additive model. A linear coefficient for each ngram in the input is obtained by taking the dot product between the ngram's embedding and the shared vector w. (C) Aug-Tree improves each split of a decision tree during fitting by (D) augmenting each keyphrase found by CART with similar keyphrases generated by an LLM.

Specifically, we
define an Aug-imodel as a method that leverages an LLM to fit an interpretable model, but does not use the LLM during inference. This allows complete transparency and often a substantial efficiency improvement (both in terms of speed and memory). Aug-imodels can address shortcomings in existing transparent models by using the world knowledge present in modern LLMs, such as information about feature correlations. We explore two instantiations of Aug-imodels: (i) Aug-GAM, which augments a generalized additive model via decoupled embeddings from an LLM and (ii) Aug-Tree, which augments a decision tree with improved features generated by calling an LLM (see Fig 1). At inference time, both are completely transparent and efficient: Aug-GAM requires only summing coefficients from a fixed dictionary while Aug-Tree requires checking for the presence of keyphrases in an input.
Across a variety of natural-language-processing datasets, both proposed Aug-imodels outperform their non-augmented counterparts. Aug-GAM can even outperform much larger models (e.g. a 6-billion parameter GPT-J model [16]), despite having 10,000x fewer parameters and no nonlinearities. We further explore Aug-imodels in a natural-language fMRI context, where we find that they can predict well and generate interesting interpretations. In what follows, Sec 2 formally introduces Aug-imodels, Sec 3 and Sec 4 show results for predictive performance and interpretation, Sec 5 explores Aug-imodels in an fMRI prediction setting, Sec 6 reviews related work, and Sec 7 concludes with a discussion.

Aug-imodels methodology: Aug-GAM and Aug-Tree
In this section, Sec 2.1 overviews limitations of existing transparent methods, Sec 2.2 introduces Aug-GAM, and Sec 2.3 introduces Aug-Tree.

Limitations of existing transparent methods
We are given a dataset of n natural-language strings X_text and corresponding labels y ∈ R^n. In transparent modeling, each string x is often represented by a bag-of-words, in which each feature x_i is a binary indicator (or count) of the presence of a single token (e.g. the word good). To model interactions between tokens, one can instead use a bag-of-ngrams representation, whereby each feature is formed by concatenating multiple tokens (e.g. the phrase not good). Using a bag-of-ngrams representation maps X_text to a feature matrix X ∈ R^{n×p}, where p is the number of unique ngrams in X_text. While this representation enables interpretability, the number of ngrams in a dataset grows exponentially with the size of the ngram (how many tokens it contains) and the vocab-size; even for a modest vocab-size of 10,000 tokens, the number of possible trigrams is already 10^12. This makes it difficult for existing transparent methods to model all trigrams without overfitting. Moreover, existing transparent methods completely fail to learn about ngrams not seen in the training set.
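To make the growth of the ngram feature space concrete, the sketch below counts unique ngrams with a simple whitespace tokenizer. The corpus and function names are illustrative, not from the paper:

```python
from collections import Counter

def bag_of_ngrams(texts, n):
    """Count every contiguous n-gram (over whitespace tokens) in a corpus."""
    counts = Counter()
    for text in texts:
        tokens = text.lower().split()
        for i in range(len(tokens) - n + 1):
            counts[" ".join(tokens[i:i + n])] += 1
    return counts

corpus = [
    "the movie was not good",
    "the movie was very good",
    "a good movie",
]
unigrams = bag_of_ngrams(corpus, 1)
bigrams = bag_of_ngrams(corpus, 2)
# The feature space grows quickly with n, and rare ngrams dominate:
print(len(unigrams), len(bigrams))                  # 7 unigrams vs. 8 bigrams
print(sum(1 for c in bigrams.values() if c == 1))   # 6 bigrams appear only once
```

Even on this tiny corpus, most bigrams occur a single time, mirroring the paper's observation that 91% of trigrams in the FPB dataset appear only once.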
Preliminaries: GAMs. Generalized additive models (GAMs) [10] take the form

g(E[y]) = β + f_1(x_1) + f_2(x_2) + · · · + f_p(x_p),    (1)

where (x_1, x_2, . . ., x_p) are the input features (i.e. ngrams), y is the target variable, g(·) is the link function (e.g., the logistic function), and each f_i is a univariate shape function with E[f_i] = 0. Due to the model's additivity, each component function f_i can be interpreted independently. Generalized linear models, such as logistic regression, are a special form of GAMs where each f_i is restricted to be linear.
Preliminaries: decision trees. CART [11] fits a binary decision tree via recursive partitioning. When growing a tree, CART chooses for each node t the split s that maximizes the impurity decrease in the responses y. For a given node t, the impurity decrease has the expression

∆(s, t, y) = h(y_t, ȳ_t) − [ (n_{t_L}/n_t) h(y_{t_L}, ȳ_{t_L}) + (n_{t_R}/n_t) h(y_{t_R}, ȳ_{t_R}) ],    (2)

where t_L and t_R denote the left and right child nodes of t respectively, n_t denotes the number of samples in node t, and ȳ_t, ȳ_{t_L}, ȳ_{t_R} denote the mean responses in each of the nodes. For classification, h(·, ·) corresponds to the Gini impurity, and for regression, h(·, ·) is the mean-squared error. Each split s is a partition of the data based on a feature in X. To grow the tree, the splitting process is repeated recursively for each child node until a stopping criterion (e.g. a max depth) is satisfied. At inference time, we predict the response of an example by following its path from the root to a leaf and then predicting with the mean value found in that leaf.
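The impurity-decrease computation can be sketched for binary classification with the Gini criterion. The labels and split masks below are a toy illustration:

```python
import numpy as np

def gini(y):
    """Gini impurity of a vector of binary labels."""
    if len(y) == 0:
        return 0.0
    p = np.mean(y)
    return 2 * p * (1 - p)

def impurity_decrease(y, split_mask):
    """Weighted decrease in Gini impurity for a candidate split.

    `split_mask` is True for samples sent to the left child node.
    """
    n = len(y)
    y_left, y_right = y[split_mask], y[~split_mask]
    return (gini(y)
            - (len(y_left) / n) * gini(y_left)
            - (len(y_right) / n) * gini(y_right))

y = np.array([1, 1, 1, 0, 0, 0])
perfect = np.array([True, True, True, False, False, False])
useless = np.array([True, True, False, True, True, False])
gain_perfect = impurity_decrease(y, perfect)   # 0.5: split removes all impurity
gain_useless = impurity_decrease(y, useless)   # 0.0: children mirror the parent
print(gain_perfect, gain_useless)
```

CART greedily picks, at each node, the feature split maximizing this quantity.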

Aug-GAM method description
To remedy the issues with the GAM model in Eq. (1), we propose Aug-GAM, an intuitive model which leverages a pre-trained LLM to extract a feature representation φ(x_i) for each input ngram x_i. This allows learning only a single linear weight vector w with a fixed dimension (which depends on the embedding dimension produced by the LLM), regardless of the number of ngrams. As a result, Aug-GAM can learn efficiently as the number of input features grows, and can also infer coefficients for unseen features. The fitted model is still a GAM, ensuring that the model can be precisely interpreted as a linear function of its inputs:

g(E[y]) = β + Σ_i w^T φ(x_i).    (3)

Fitting Aug-GAM is similar to the popular approach of finetuning a single linear layer on top of LLM embeddings. However, it requires extra steps that separately extract/embed each ngram to keep the contributions to the prediction strictly additive across ngrams (see Fig 1A): (i) Extracting ngrams: To ensure input ngrams are interpretable, ngrams are constructed using a word-level tokenizer (here, spaCy [17]). We select the size of ngrams to be used via cross-validation. (ii) Extracting embeddings: Each ngram is fed through the LLM to retrieve a fixed-size embedding.¹ (iii) Summing embeddings: The embeddings of each ngram in the input are summed to yield a single fixed-size vector, ensuring additivity of the final model. (iv) Fitting the final linear model to make predictions: A linear model is fit on the summed embedding vector. We choose the link function g to be the logit function (or the softmax for multi-class) for classification and the identity function for regression. In both cases, we add ℓ2 regularization over the parameters w in Eq. (3).
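Steps (i)-(iv) can be sketched end-to-end as below. This is a minimal illustration, not the paper's implementation: `embed` is a toy deterministic random vector standing in for the LLM embedding, tokenization is plain whitespace splitting rather than spaCy, and step (iv) uses closed-form ridge regression (the regression/identity-link case) in place of cross-validated fitting:

```python
import numpy as np

EMB_DIM = 16

def embed(ngram):
    """Toy stand-in for the LLM: a deterministic random vector per ngram.
    In Aug-GAM this would be a fixed-size embedding from a pre-trained LLM."""
    local = np.random.default_rng(abs(hash(ngram)) % (2**32))
    return local.standard_normal(EMB_DIM)

def extract_ngrams(text, max_n=2):
    """Step (i): extract all ngrams up to size max_n over word-level tokens."""
    tokens = text.lower().split()
    return [" ".join(tokens[i:i + n])
            for n in range(1, max_n + 1)
            for i in range(len(tokens) - n + 1)]

def featurize(text):
    """Steps (ii)-(iii): embed each ngram separately, then sum the embeddings."""
    return sum(embed(g) for g in extract_ngrams(text))

texts = ["not good at all", "very good movie", "truly awful film", "really great acting"]
y = np.array([0.0, 1.0, 0.0, 1.0])

# Step (iv): fit a ridge-regularized linear model on the summed embeddings.
X = np.stack([featurize(t) for t in texts])
lam = 1.0
w = np.linalg.solve(X.T @ X + lam * np.eye(EMB_DIM), X.T @ y)

# The model stays additive: each ngram contributes exactly w @ embed(ngram).
contribs = {g: float(w @ embed(g)) for g in extract_ngrams("not good at all")}
pred = sum(contribs.values())
assert np.isclose(pred, X[0] @ w)  # the prediction decomposes over ngrams
```

The final assertion is the key property: because embeddings are summed before the linear layer, the prediction is exactly a sum of per-ngram scalar contributions.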
Computational considerations. Aug-GAM is inexpensive to fit, as (1) the pre-trained LLM is not modified in any way, and can be any existing off-the-shelf model, and (2) Aug-GAM only requires fitting a fixed-size linear model. After training, the model can be converted to a dictionary of scalar coefficients for each ngram, where the coefficient is the dot product between the ngram's embedding and the fitted weight vector w (Fig 1B). This makes inference extremely fast and converts the model to have size equal to the number of fitted ngrams. When new ngrams are encountered at test time, the coefficients for these ngrams can optionally be inferred by again taking the dot product between the ngram's embedding and the fitted weight vector w.
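The dictionary conversion described above can be sketched as follows; `embed` is again a toy deterministic stand-in for the LLM embedder, and `w` is assumed to have been fit already:

```python
import numpy as np

EMB_DIM = 16

def embed(ngram):
    """Toy deterministic stand-in for the LLM embedding of an ngram."""
    local = np.random.default_rng(abs(hash(ngram)) % (2**32))
    return local.standard_normal(EMB_DIM)

# Assume w was already fit during Aug-GAM training.
w = np.random.default_rng(0).standard_normal(EMB_DIM)

# Convert the model to a dictionary of scalar coefficients: one dot product
# per training-set ngram. The LLM is not needed after this point.
train_ngrams = ["good", "not good", "bad", "movie", "great acting"]
coef = {g: float(w @ embed(g)) for g in train_ngrams}

def predict(ngrams):
    """Inference is a dictionary lookup plus a sum (no GPU, no LLM).
    Unseen ngrams default to 0 here; optionally they could be embedded
    once and cached."""
    return sum(coef.get(g, 0.0) for g in ngrams)

score = predict(["not good", "movie"])
assert np.isclose(score, coef["not good"] + coef["movie"])
```

The model's storage cost is thus one scalar per fitted ngram, which is what enables the large inference speed/memory gains cited in the abstract.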

Aug-Tree method description
Aug-Tree improves upon a CART decision tree by augmenting features with generations from an LLM. This helps capture correlations between ngrams, including correlations with ngrams that are not present in the training data. Aug-Tree does not modify the objective in Eq. (2) but rather modifies the procedure for fitting each individual split s (Fig 1D). While CART restricts each split to a single ngram, Aug-Tree lets each split fit a disjunction of ngrams (e.g. ngram1 ∨ ngram2 ∨ ngram3). The disjunction allows a split to capture sparse interactions, such as synonyms in natural language. This can help mitigate overfitting by allowing individual splits to capture concrete concepts, rather than requiring many interacting splits. When fitting each split, Aug-Tree starts with the ngram which maximizes the objective in Eq. (2), just as CART would do, e.g. not good. Then, we query an LLM to generate similar ngrams to include in the split, e.g. bad, poor, awful, ..., horrendous. Specifically, we query GPT-3 (text-davinci-003) [1] with the prompt Generate 100 concise phrases that are very similar to the keyphrase:\nKeyphrase: "{keyphrase}"\n1. and parse the outputs into a list of ngrams. We greedily screen each ngram by evaluating the impurity of the split when including the ngram in the disjunction; we then exclude any ngram which increases the split's impurity, resulting in a shortened list of ngrams, e.g. bad, poor, dull. See extended algorithm details in Algorithm B2.
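The greedy screening step can be sketched as below. The texts, labels, and `llm_suggestions` list are all hypothetical (in the real method the suggestions come from prompting GPT-3), and ngram matching is simplified to substring containment:

```python
import numpy as np

def gini(y):
    """Gini impurity of binary labels."""
    if len(y) == 0:
        return 0.0
    p = np.mean(y)
    return 2 * p * (1 - p)

def split_impurity(texts, y, disjunction):
    """Weighted impurity after splitting on: does the text contain ANY
    ngram in the disjunction? (Substring matching, for simplicity.)"""
    mask = np.array([any(g in t for g in disjunction) for t in texts])
    n = len(y)
    return (mask.sum() / n) * gini(y[mask]) + ((~mask).sum() / n) * gini(y[~mask])

texts = ["not good at all", "a bad film", "very poor acting",
         "great movie", "truly solid work"]
y = np.array([0, 0, 0, 1, 1])

# Start from the best single CART ngram, then greedily screen LLM suggestions:
# exclude any suggestion that increases the split's impurity.
disjunction = ["not good"]
llm_suggestions = ["bad", "poor", "awful", "great"]  # hypothetical LLM output
for candidate in llm_suggestions:
    if (split_impurity(texts, y, disjunction + [candidate])
            <= split_impurity(texts, y, disjunction)):
        disjunction.append(candidate)
print(disjunction)
```

On this toy data the screen keeps the negative synonyms but rejects "great", since routing the positive review leftward would raise the split's impurity.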
Computational considerations. As opposed to Aug-GAM, Aug-Tree uses an LLM API rather than LLM embeddings, which may be more desirable depending on access to compute. The number of LLM calls required is proportional to the number of nodes in the tree. During inference, the LLM is no longer needed and making a prediction simply requires checking an input for the presence of specific ngrams along one path in the tree. Storing an Aug-Tree model requires memory proportional to the number of raw strings associated with tree splits, usually substantially reducing memory over the already small Aug-GAM model. We explore variations of Aug-Tree (such as using LLM embeddings rather than an API) in Sec B.
Results: Prediction performance

Experimental setup
Datasets. Table 1 shows the datasets we study: 4 widely used text-classification datasets spanning different domains (e.g. classifying the emotion of tweets [18], the sentiment of financial news sentences [19], or the sentiment of movie reviews [20, 21]), and 1 scientific text regression dataset (described in Sec 5) [22]. Across datasets, the number of unique ngrams grows quickly from unigrams to bigrams to trigrams. Moreover, many ngrams appear very rarely, e.g., in the Financial Phrasebank (FPB) dataset, 91% of trigrams appear only once in the training dataset.

Aug-GAM settings. We compare Aug-GAM to four interpretable baseline models: bag-of-ngrams, TF-IDF (term frequency-inverse document frequency) [23], GloVe [24],² and a model trained on BERT embeddings for each unigram in the input (which can be viewed as running Aug-GAM with only unigrams).
We use BERT (bert-base-uncased) [3] as the LLM for extracting embeddings, after finetuning on each dataset.³ In each case, a model is fit via cross-validation on the training set (to tune the amount of ℓ2 regularization added) and its accuracy is evaluated on the test set.
Aug-Tree settings. We compare Aug-Tree to two decision tree baselines: CART [11] and ID3 [26], and we use bigram features. In addition to individual trees, we fit bagging ensembles, where each tree is created using a bootstrap sample of the same size as the original dataset (as done in Random Forest [27]) and has depth 8. This hurts interpretability, but can improve predictive performance and calibration. For simplicity, we run Aug-Tree only in a binary classification setting; to do so, we take two opposite classes from each multiclass dataset (Negative/Positive for FPB and Sadness/Joy for Emotion).

Aug-GAM text-classification performance
Generalization as a function of ngram size. Fig 2A shows the test accuracy of Aug-GAM as a function of the ngram size used for computing features. Aug-GAM outperforms the interpretable baselines, achieving a considerable increase in accuracy across three of the four datasets. Notably, Aug-GAM accuracy increases with ngram size, whereas the accuracy of baseline methods decreases or plateaus. This is likely due to Aug-GAM fitting only a fixed-size parameter vector, helping to prevent overfitting.
Comparing Aug-GAM performance with black-box baselines. Table 2 shows the test accuracy results for various models when choosing the size of ngrams via cross-validation. Compared with interpretable baselines, Aug-GAM shows considerable gains on three of the datasets and only a minor gain on the tweet dataset (Emotion), likely because this dataset requires fitting fewer high-order interactions.
Compared with the zero-shot performance of the much larger GPT models (6-billion parameter GPT-J [16] and 175-billion parameter GPT-3, text-davinci-002 [1]),⁴ Aug-GAM outperforms GPT-J. Aug-GAM lags slightly behind GPT-3 for binary classification problems (Rotten Tomatoes and SST2) but outperforms GPT-3 for multi-class classification problems (FPB and Emotion). The best black-box baseline (a finetuned BERT model) outperforms Aug-GAM by 4%-6% accuracy. This is potentially a reasonable tradeoff in settings where interpretability, speed, or memory bottlenecks are critical.
Complementing Aug-GAM with a black-box model. In some settings, it may be useful to use Aug-GAM on relatively simple samples (for interpretability/memory/speed) but relegate relatively difficult samples to a black-box model. To study this setting, we first predict each sample with Aug-GAM, then assess its confidence (how close its predicted probability for the top class is to 1). If the confidence is above a prespecified threshold, we use the Aug-GAM prediction. Otherwise, we compute the sample's prediction using a finetuned BERT model. Fig 2B shows the accuracy for the entire test set as we vary the percentage of samples predicted with Aug-GAM. Since Aug-GAM yields probabilities that are reasonably calibrated (see Fig A1), rather than the accuracy linearly interpolating between Aug-GAM and BERT, a large percentage of samples can be predicted with Aug-GAM while incurring little to no drop in accuracy. For example, when using Aug-GAM on 50% of samples, the average drop in test accuracy is only 0.0053.
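The confidence-based routing scheme can be sketched as below for binary classification. The probability and prediction arrays are hypothetical, chosen only to illustrate the mechanism:

```python
import numpy as np

def route_predictions(probs_gam, preds_bert, threshold=0.9):
    """Use the cheap Aug-GAM prediction when it is confident; otherwise
    fall back to the black-box model. `probs_gam` holds Aug-GAM's predicted
    probability of the positive class."""
    preds_gam = (probs_gam >= 0.5).astype(int)
    # Confidence = predicted probability of the top class.
    confident = np.maximum(probs_gam, 1 - probs_gam) >= threshold
    return np.where(confident, preds_gam, preds_bert), confident

# Hypothetical predicted probabilities for five test samples.
probs_gam = np.array([0.98, 0.55, 0.03, 0.60, 0.95])
preds_bert = np.array([1, 0, 0, 1, 1])  # black-box fallback predictions
final, used_gam = route_predictions(probs_gam, preds_bert, threshold=0.9)
print(final, used_gam.mean())  # here 60% of samples never touch the black box
```

Raising the threshold shifts more samples to the black-box model; because the Aug-GAM probabilities are reasonably calibrated, accuracy degrades slowly as the threshold is lowered.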
Tradeoffs between accuracy and efficiency. When inference memory/speed matters, Aug-GAM can be converted to a dictionary of coefficients, whose size is the number of ngrams that appeared in training (see Table 1). For a trigram model, this yields roughly a 1,000x reduction in model size compared to the ∼110 million trainable parameters in BERT, with much room for further size reduction (e.g. simply removing coefficients for trigrams that appear only once yields another 10-fold size reduction). Inference is nearly instantaneous, as it requires only looking up coefficients in a dictionary and then a single sum (and does not require a GPU). Sec A.1 explores accuracy/efficiency tradeoffs. For example, Aug-GAM performance degrades gracefully when the model is compressed by removing its smallest coefficients. In fact, the test accuracy of Aug-GAM models trained with 4-grams on the Emotion and Financial Phrasebank datasets actually improves after removing 50% of the original coefficients (Fig A2A). Additionally, one can vary the size of ngrams used at test time without a severe performance drop, potentially enabling compressing the model by orders of magnitude.
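The magnitude-based compression described above amounts to dropping the smallest entries of the coefficient dictionary. A minimal sketch, with a made-up coefficient dictionary:

```python
def compress(coef, keep_fraction=0.5):
    """Drop the smallest-magnitude coefficients from an Aug-GAM dictionary;
    pruned ngrams then simply contribute 0 at inference time."""
    ranked = sorted(coef.items(), key=lambda kv: abs(kv[1]), reverse=True)
    n_keep = int(len(ranked) * keep_fraction)
    return dict(ranked[:n_keep])

# Hypothetical fitted coefficients: strong sentiment ngrams vs. near-zero noise.
coef = {"not good": -1.8, "great": 1.5, "movie": 0.02, "the": -0.01}
small = compress(coef, keep_fraction=0.5)
print(small)  # only the large-magnitude ngrams survive
```

Since near-zero coefficients contribute little to the additive prediction, halving the dictionary often changes accuracy very little (and, per Fig A2A, can even improve it).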

Aug-Tree generalization performance
We now investigate the predictive performance of Aug-Tree, measured by the test ROC AUC on the previous text-classification datasets altered for binary classification. Note that the performance of all tree-based methods on the studied datasets is below the performance of the GAM methods in Sec 3.2 (see Table B5 for a direct comparison). Nevertheless, Aug-Tree models maintain potential advantages, such as storing far fewer parameters, clustering important features together, and better modeling long-range interactions.
Fig 3A shows the performance for Aug-Tree as a function of tree depth compared to decision tree baselines. Aug-Tree shows improvements that are sometimes small (e.g. for Financial Phrasebank) and sometimes relatively large (e.g. for Emotion). Fig 3B shows the performance of a bagging ensemble of trees with different tree methods used as the base estimator. Here, using Aug-Tree shows a reliable and significant gain across all datasets compared to ensembles of baseline decision-tree methods. This suggests that LLM augmentation may help to diversify or decorrelate individual trees in the ensemble.
Varying Aug-imodels settings. We investigate many variations in the settings used for Aug-imodels. Table B4 shows variations of different hyperparameters for Aug-Tree, such as using embeddings or dataset-specific prompts to expand keyphrases. Table A2 shows how generalization accuracy changes when the LLM used to extract embeddings for Aug-GAM is varied, or when different layers / ngram selection techniques are used. Across the variations, embeddings from finetuned models yield considerably better results than embeddings from non-finetuned models.

Interpreting fitted models
In this section, we interpret fitted Aug-imodels. We first inspect an Aug-GAM model fitted using unigram and bigram features on the SST2 dataset, which achieves 84% test accuracy. Next, we analyze the keyphrase expansions made in fitted Aug-Tree bagging ensembles.
Fitted Aug-GAM coefficients match human scores. A fitted Aug-GAM model can be interpreted for a single prediction (i.e. getting a score for each ngram in a single input, as in Fig 1) or for an entire dataset (i.e. by inspecting its fitted coefficients). Fig 4A visualizes the fitted Aug-GAM coefficients (i.e. the contribution to the prediction w^T φ(x_i)) with the largest absolute values across the SST2 dataset. To show a diversity of ngrams, we show every fifth ngram. The fitted coefficients are semantically reasonable and many contain strong interactions (e.g. not very is assigned to be negative whereas without resorting is assigned to be positive). This form of model visualization allows a user to audit the model with prior knowledge. Moreover, these coefficients are exact and therefore avoid summarizing interactions, making them considerably more faithful than post-hoc methods, such as LIME [29] and SHAP [30] (see Sec A.2 for a comparison). Fig 4B compares the fitted Aug-GAM coefficients to human-labeled sentiment phrase scores for unigrams/bigrams in SST (note: these continuous scores are separate from the binary sentence labels used for training in the SST2 dataset). Both are centered, so that 0 is neutral sentiment and positive/negative values correspond to positive/negative sentiment, respectively. There is a strong positive correlation between the coefficients and the human-labeled scores (Spearman rank correlation ρ = 0.63), which considerably improves over coefficients from a bag-of-bigrams model trained on SST2 (ρ = 0.39).
Inferred Aug-GAM coefficients for unseen ngrams match human scores. One strength of Aug-GAM is its ability to infer linear coefficients for ngrams that were not seen during training. Whereas baseline models generally assign each unknown ngram the same coefficient (e.g. 0), Aug-GAM can effectively assign these new ngrams accurate coefficients.

Aug-Tree augmented splits contain relevant phrases. A fitted Aug-Tree model can be easily interpreted for a single prediction (i.e. by inspecting the ngrams that triggered relevant splits) or by visualizing the entire tree (e.g. Fig 1C). Here, we additionally analyze how well each ngram found by CART matches the augmented ngrams found by the LLM; the better this match is, the easier it is to interpret a split.
Table 3 shows examples of the ngrams which were most frequently augmented when fitting a bagging ensemble of 40 Aug-Trees to the four text-classification datasets in Table 1. Added ngrams seem qualitatively reasonable, e.g. the keyphrase good expands to fine, highly, solid, ..., valuable. We evaluate how well the expansions match the original CART ngram via human evaluation scores. Human evaluators are given the original ngram and the added ngrams, then instructed "You are given a keyphrase along with related keyphrases. On a scale of 1 (worst) to 5 (best), how well do the related keyphrases match the example keyphrase?"⁵ Table 3 shows that the average human score for splits in each dataset is consistently greater than 4. This is substantially higher than the baseline score of 1.3 assigned by human evaluators when 15 ngrams and expansions are randomly paired and evaluated. Table B3 gives more details on ngram expansions.

Table 3: Examples of most frequently augmented ngrams for each dataset when fitting an ensemble of 40 Aug-Trees. Human scores measure the similarity between an ngram and its expansion. They range from 1 (worst match) to 5 (best match), and the baseline score when ngrams and expansions are randomly paired and evaluated is 1.3±0.1. Error bars show standard error of the mean. Abbreviations: FPB = Financial Phrasebank, RT = Rotten Tomatoes.

Analyzing fMRI data with Aug-imodels
We now explore Aug-imodels in a real-world neuroscience context. A central challenge in neuroscience is understanding how and where semantic concepts are represented in the brain. To meet this challenge, one line of study predicts the response of different brain voxels (i.e. small regions in space) to natural-language stimuli. We analyze data from a recent study in which the authors collect functional MRI (fMRI) responses as human subjects listen to hours of narrative stories [22]. The fMRI responses studied here contain 95,556 voxels from a single subject, with 9,461 time points used for training/cross-validation and 291 time points used for testing. We predict the continuous response for each voxel at each time point using the 20 words that precede the time point.⁶ Seminal work on this task found that linear models of word vectors could effectively predict voxel responses [31], and more recent work shows that LLMs can further improve predictive performance [32, 33]. Aug-GAM is well-suited to this task, as it combines low-level word information with the contextualized information present in higher-order ngrams, both of which have been found to contribute to fMRI representations of text [34]. Going beyond prediction performance, Fig 5C investigates a simple example of how Aug-GAM could help interpret an underlying brain region. We first select the voxel which is best predicted by Aug-GAM (achieving a test correlation of 0.76) and then visualize the largest fitted Aug-GAM coefficients for that voxel. These correspond to the ngrams that increase the activity of the fMRI voxel the most. Interestingly, these ngrams qualitatively correspond to understandable concepts: questioning, e.g. "are you sure", often combined with disbelief/incredulity, e.g. "wow I never". Fig 5D shows two examples of voxels that are better predicted by Aug-Tree than Aug-GAM (Aug-Tree yields test correlations of 0.35 and 0.36). These two voxels are both related to someone speaking, but they seem to depend on
interactions between the noun (me or you) and the verb (says). To elicit a large response, both must be present, something which is difficult to capture in additive models, even with ngrams, since these words may not be close together in a sentence. This interpretation approach could be applied more rigorously to generate hypotheses for text inputs that activate brain regions, which could then be tested with follow-up fMRI experiments.

Background and related work
GAMs. There is a large literature on additive models being used for interpretable modeling. This includes generalized additive models (GAMs) [10], which have achieved strong performance in various domains by modeling individual component functions/interactions using regularized boosted decision trees [35] and, more recently, neural networks [36]. However, existing GAM methods are limited in their ability to model the high-order feature interactions that arise in NLP. Meanwhile, NLP has seen great success in models which build strong word-level representations, e.g. word2vec [37, 38], GloVe [24], FastText [39], and ELMo [40]. By replacing such models with LLM embeddings, Aug-GAM enables easily modeling ngrams of different lengths without training a new model. Moreover, unlike earlier methods, LLMs can incorporate information about labels into the embeddings (e.g. by first finetuning an LLM on a downstream prediction task).
Decision trees. There is a long history of greedy methods for fitting decision trees, e.g., CART [11], ID3 [26], and C4.5 [41]. More recent work has explored fitting trees via global optimization rather than greedy algorithms [42][43][44]; this can improve performance given a fixed tree size but incurs a high computational cost. Other recent studies have improved trees after fitting through regularization [45] or iterative updates [46]. Beyond trees, there are many popular classes of rule-based models, such as rule sets [47], rule lists [48, 49], and tree sums [15]. Aug-Tree addresses a common problem shared by rule-based approaches: modeling the sparse, correlated features that are common in tasks such as text classification.
Beyond fitting a single tree, tree ensembles such as Random Forest [27], gradient-boosted trees [50], XGBoost [51], and BART [52], have all shown strong predictive performance in diverse settings.These ensembling approaches are compatible with Aug-Tree, as it can be used as the base estimator in any of these approaches.
Interpreting/distilling neural networks. The work here is related to studies that aim to make neural networks more interpretable. For example, models can make predictions by comparing inputs to prototypes [53, 54], by predicting intermediate interpretable concepts [55][56][57], by using LLMs to extract prompt-based features [58, 59], by distilling a neural network into a mostly transparent model [60], or by distilling into a fully transparent model (e.g. adaptive wavelets [13] or an additive model [61]). Separately, many works use neural network distillation to build more efficient (but still black-box) neural network models, e.g. [62, 63].
Feature and feature-interaction importances. Loosely related to this work are post-hoc methods that aim to help understand a black-box model, i.e. by providing feature importances using methods such as LIME [29], SHAP [64], and others [65, 66]. However, these methods lose some information by summarizing the model and suffer from issues with summarizing interactions [67, 68]. Slightly more related are works which aim to explain feature interactions or transformations in neural networks [69][70][71], but these works fail to explain the model as a whole and are again less reliable than having a fully transparent model.

Discussion
Aug-imodels provide a promising direction towards future methods that reap the benefits of both LLMs and transparent models in NLP: high accuracy along with interpretability/efficiency. This potentially opens the door for introducing LLM-augmented models in high-stakes domains, such as medical decision-making, and in new applications on compute-limited hardware. Aug-imodels are currently limited to applications for which an effective LLM is available, and thus may not work well for very esoteric NLP tasks. However, as LLMs improve, the predictive performance of Aug-imodels should continue to improve and expand to more diverse NLP tasks. More generally, Aug-imodels can be applied to domains outside of NLP where effective foundation models are available (e.g. computer vision or protein engineering).
Aug-imodels can be readily extended to new model forms beyond additive models and trees. Other transparent models, such as rule lists, rule sets, and prototype-based models, could all potentially benefit from LLM augmentation during training time. In all these cases, LLM augmentation could use LLM embeddings (as is done in Aug-GAM), use LLM generations (as is done in Aug-Tree), or use LLMs in new ways. Aug-GAM could be augmented by building on the nonlinearity present in GAMs such as the explainable boosting machine [35], to nonlinearly transform the embedding for each ngram with a model before summing to obtain the final prediction. Additionally, Aug-GAM could fit long-range interaction terms as opposed to only ngrams. Aug-Tree could leverage domain knowledge to engineer more meaningful prompts for expanding ngrams or for extracting relevant ngrams. Both models can be further studied to improve their compression (potentially with LLM-guided compression techniques) or to extend their capabilities to tasks beyond classification/regression, such as sequence prediction or outlier detection. We hope that the introduction of Aug-imodels can help push improved prediction performance into high-stakes applications, improve interpretability for scientific data, and reduce unnecessary energy/compute usage.

A.1 Accuracy/efficiency tradeoffs

as positive and negative labels and require the LLM to rank the two options. Human-written prompts are adapted to this template from open-source prompts available through PromptSource [28]. Fig A2A shows the prediction performance when compressing the Aug-GAM model (fit using 4-grams and finetuned BERT) by setting the coefficients with the smallest magnitude to zero. Some models require only a few coefficients to perform well and some models (e.g. the Emotion and Financial Phrasebank models) predict more accurately when using less than 50% of the original coefficients.

A.2 Comparison with post-hoc feature importance
The coefficients learned by Aug-GAM often differ from importances assigned by post-hoc feature-importance methods. Aug-GAM learns a single coefficient for each ngram across the dataset, allowing for auditing/editing the model with visualizations such as Fig 4. In contrast, popular methods for post-hoc feature importance, such as LIME [29] and SHAP [30], yield importance scores that vary based on the context in each input. This can be useful for debugging complex nonlinear models, but these scores (i) are approximations, (ii) must summarize nonlinear feature interactions, and (iii) vary across predictions, making transparent models preferable whenever possible.

Fig A5 compares each learned bigram coefficient against the coefficients of its constituent unigrams (blue and orange bars). It is clear that the bigram coefficient is not the simple naive sum of the unigram coefficients (dashed black bar), and the learned coefficients make intuitive sense, suggesting that this Aug-GAM model has successfully learned interactions.
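The interaction check described above can be made concrete: compare a learned bigram coefficient with the naive sum of its unigram coefficients. A minimal sketch with hypothetical coefficient values:

```python
# Hypothetical learned Aug-GAM coefficients (illustrative values only).
coef = {"not": -0.2, "good": 0.8, "not good": -0.9}

def interaction_strength(bigram, coef):
    """How far the bigram coefficient departs from the naive sum of its unigrams."""
    u1, u2 = bigram.split()
    naive_sum = coef[u1] + coef[u2]
    return coef[bigram] - naive_sum

print(interaction_strength("not good", coef))  # -0.9 - 0.6 = -1.5
```

A large gap between the bigram coefficient and the unigram sum, as here, is the signature of a learned interaction (e.g. negation flipping sentiment).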
Table B4 explores different variations of Aug-Tree. The top row shows a single tree learned with Aug-Tree using its default parameters, which achieves the best performance across the datasets. Table B4 also shows results for different algorithmic choices, such as replacing the generic prompt with a dataset-specific one (Aug-Tree (Contextual prompt)) and searching for new keyphrases using 5 CART features instead of one (Aug-Tree (5 CART features)). We also consider preprocessing the data differently, using stemming (with the Porter stemmer) or using trigrams rather than bigrams.
One major variation we study is using LLM embeddings to find keyphrases, rather than querying via a prompt (Aug-Tree (Embeddings)). Specifically, we consider expanding keywords by finding the keyphrases that are closest in embedding space (measured by Euclidean distance) to the original keyphrase. This option may be computationally desirable, as it may require only a smaller LLM to compute effective embeddings (e.g. BERT [3]) compared to the larger LLM required to directly generate relevant keyphrases (e.g. GPT-3 [1]). However, finding the closest embeddings requires making more calls to the LLM, as embeddings must be calculated and compared across all ngrams in X_text.
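The embedding-based expansion can be sketched as follows; `embed` is a hypothetical stand-in for an LLM encoder (e.g. BERT), replaced here by a toy lookup table:

```python
import numpy as np

def expand_by_embedding(keyphrase, candidates, embed, n_expansions=3):
    """Return the candidate ngrams closest to `keyphrase` in embedding space
    (Euclidean distance), as in the Aug-Tree (Embeddings) variation."""
    target = embed(keyphrase)
    dists = [np.linalg.norm(embed(c) - target) for c in candidates]
    order = np.argsort(dists)
    return [candidates[i] for i in order[:n_expansions]]

# Toy 2-dim embeddings; in practice these would come from an LLM encoder.
toy_vocab = {"bad": [1.0, 0.0], "awful": [0.9, 0.1],
             "great": [-1.0, 0.0], "terrible": [0.8, 0.2]}
embed = lambda w: np.array(toy_vocab[w])
print(expand_by_embedding("bad", ["awful", "great", "terrible"], embed, n_expansions=2))
```

Note the cost tradeoff mentioned above: this requires embedding every candidate ngram in the corpus, whereas prompting generates expansions directly.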
Algorithm B2 Aug-Tree algorithm for fitting a single split.
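The single-split procedure can be sketched in Python. The callables below are hypothetical stand-ins for the paper's components (CART keyphrase selection, LLM expansion, and a split-quality score, e.g. impurity reduction); only the control flow is meant to match:

```python
def fit_augmented_split(X_text, y, fit_cart_keyphrase, llm_expand, split_score):
    """Sketch of fitting one Aug-Tree split: a CART keyphrase plus screened
    LLM expansions. The split rule becomes: "does the input contain any of
    the returned keyphrases?"."""
    keyphrases = [fit_cart_keyphrase(X_text, y)]   # (1) best single-keyphrase split
    best = split_score(keyphrases, X_text, y)
    for candidate in llm_expand(keyphrases[0]):    # (2) LLM-proposed expansions
        trial = keyphrases + [candidate]
        score = split_score(trial, X_text, y)
        if score > best:                           # (3) screening keeps only helpful ones
            keyphrases, best = trial, score
    return keyphrases

# Toy stand-ins: a fixed keyphrase, fixed expansions, and a score counting
# correctly separated examples (label 0 = should match the split).
contains_any = lambda text, kps: any(kp in text for kp in kps)
score = lambda kps, X, y: sum(contains_any(x, kps) == (lbl == 0) for x, lbl in zip(X, y))
X, y = ["bad movie", "awful plot", "great film"], [0, 0, 1]
print(fit_augmented_split(X, y, lambda X, y: "bad", lambda kp: ["awful", "great"], score))
```

In this toy run, "awful" survives screening because it improves the split, while "great" is rejected, mirroring the heavy screening reported in Table B3.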

Fig. 2 (A) Test accuracy as a function of ngram size. As the ngram size (i.e. the number of tokens in the ngram) increases, the gap between Aug-GAM and the baselines grows. Averaged over three random cross-validation splits; error bars are standard errors of the mean (many are within the points). (B) Accuracy when using Aug-GAM in combination with BERT. A large percentage of samples can be accurately predicted with Aug-GAM.

Fig. 3 Test performance as a function of (A) tree depth and (B) number of estimators. Values are averaged over 3 random dataset splits; error bars show the standard error of the mean (many are within the points).

Fig. 4 Top and bottom contributing ngrams to an Aug-GAM model trained on SST2 bigrams are (A) qualitatively semantically accurate and (B) match human-labeled phrase sentiment scores. For the same Aug-GAM model, which is trained only on bigrams, inferred trigram coefficients are (C) qualitatively semantically accurate and (D) match human-labeled phrase sentiment scores.
As one example, Fig 4C shows that the Aug-GAM model trained only on bigrams in Fig 4A/B can automatically infer coefficients for trigrams (which were not fit during training). The inferred coefficients are semantically meaningful, even capturing three-way interactions, such as not very amusing. To show a diversity of ngrams, we show every 20th ngram. Fig 4D compares the coefficients to the human-labeled SST phrase sentiment for all trigrams in SST. Again, there is a strong correlation: the Aug-GAM coefficients achieve a rank correlation ρ = 0.71, which even outperforms a bag-of-words model trained directly on trigrams (ρ = 0.49).
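This generalization to unseen trigrams is possible because an Aug-GAM coefficient is obtained from the fitted linear weights applied to the ngram's (decoupled) LLM embedding, so a coefficient can be read off for any ngram. A minimal sketch, where the toy embedding stands in for an LLM encoder:

```python
import numpy as np

def infer_coefficient(ngram, weights, embed):
    """Aug-GAM coefficient for an arbitrary ngram: the inner product of the
    fitted linear weights with the ngram's embedding."""
    return float(np.dot(weights, embed(ngram)))

# Toy 3-dim embedding standing in for an LLM encoder; illustrative weights.
toy = {"not very amusing": np.array([1.0, -1.0, 0.0])}
weights = np.array([-0.5, 0.5, 2.0])
print(infer_coefficient("not very amusing", weights, toy.__getitem__))  # -1.0
```

No retraining is needed: inferring a trigram coefficient from a bigram-trained model only requires one embedding call and one dot product.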
Fig 5A visualizes the voxels in the cortex which are better predicted by Aug-GAM than by BERT. The improvements are often spatially localized within well-studied brain regions such as auditory cortex (AC). Fig 5B shows that the test performance for Aug-GAM (measured by the Pearson correlation coefficient ρ) exceeds that of the black-box BERT baseline. Sec C gives further data details and comparisons, e.g. Aug-GAM also outperforms other linear baselines.

Fig. 5 Aug-imodels prediction performance and interpretation for fMRI voxels. (A) Map of the difference between the performance of Aug-GAM and BERT for fMRI voxel prediction across the cortex. Positive values (red) show where Aug-GAM outperforms BERT (measured by correlation on the test set). (B) Aug-GAM outperforms BERT when averaging across all voxels (or just over the 1%/5% of voxels with the highest test correlations). Standard errors of the mean are all less than 0.0015. (C) Example Aug-GAM model for a single voxel (visualized with the top Aug-GAM coefficients). (D) Example Aug-Tree model for two voxels.
Fig A2B shows the accuracy of the same models as in Fig A2A as the order of ngrams used only for testing is varied. As the number of features used for testing increases, performance tends to increase but interpretations become more difficult. Fig A3 characterizes the full tradeoff between the number of ngrams used for fitting versus testing for all datasets. Generally, the best performance is achieved when the same number of ngrams is used for training and testing (the diagonal). Performance tends to degrade significantly when fewer ngrams are used for testing than for training (lower left).

Fig. A5 Aug-GAM accurately learns interactions rather than simply summing the contributions of individual unigrams.

Table 1 Overview of datasets studied here. The number of ngrams grows quickly with the size of the ngram.

Table 2 Test accuracy for different models. Aug-GAM yields improvements over interpretable baselines and is competitive with some black-box baselines. Errors show standard error of the mean over 3 random data splits (or 3 different prompts for GPT models).

Table A2 Generalization accuracy varies depending on the model used to extract embeddings. Finetuning the embedding model improves Aug-GAM performance, a BERT model seems to outperform a DistilBERT model, and the layer used to extract embeddings does not have too large an effect. The top two methods are bolded in each column.
A.1 Test-time tradeoffs between accuracy and interpretability/speed

The ability to effectively generalize to unseen tokens in Fig 4C/D raises the question of whether one can vary the order of ngrams used at test time to trade off accuracy against interpretability (i.e. how many features are used to make a prediction). Depending on the relative importance of accuracy and interpretability for a given problem, one may choose to use a different number of features for testing. Fig A2 suggests that this is feasible.
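The tradeoff follows from how an Aug-GAM prediction is formed: extract all ngrams up to a chosen order and sum their coefficients, so the order used at test time directly controls how many features contribute. A minimal sketch with a hypothetical coefficient lookup:

```python
def predict_score(text, coef, max_order=2):
    """Sum coefficients of all ngrams up to `max_order` found in the text."""
    tokens = text.split()
    score = 0.0
    for n in range(1, max_order + 1):
        for i in range(len(tokens) - n + 1):
            ngram = " ".join(tokens[i:i + n])
            score += coef.get(ngram, 0.0)  # ngrams without a coefficient contribute 0
    return score

# Hypothetical coefficients; raising max_order pulls in the bigram term.
coef = {"very": 0.1, "amusing": 0.6, "not": -0.3, "not very": -0.4}
print(predict_score("not very amusing", coef, max_order=1))  # 0.4
print(predict_score("not very amusing", coef, max_order=2))  # 0.0
```

Lowering `max_order` at test time yields a prediction built from fewer, simpler features, at some cost in accuracy, exactly the tradeoff Fig A2 measures.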

Table B3 Metadata on keyphrase expansions. Results are averaged over keyphrases found in the 4 text-classification datasets in Table 1 when fitting a 40-tree bagging ensemble. The LLM is queried for 100 expansion candidates, but due to imperfect LLM generations, only 91.6 candidates are generated on average. After deduplication (converting to lowercase, removing whitespace, etc.), only 83.3 candidates remain. Screening removes almost all candidates, leaving only 0.8 candidates on average.

Table B4 Performance (ROC AUC) for variations of Aug-Tree. Values are averaged over 3 random dataset splits; error bars are standard error of the mean (many are within the points).

Table B5 Performance (accuracy) for Aug-Tree and Aug-Tree Ensemble. Values are averaged over 3 random dataset splits; error bars are standard error of the mean (many are within the points). *Emotion and Financial phrasebank results are not directly comparable to Table 2, as they have been modified for binary classification.