VetTag: improving automated veterinary diagnosis coding via large-scale language modeling

Unlike human medical records, most veterinary records are free text without standard diagnosis codes. The lack of systematic coding is a major barrier to the growing interest in leveraging veterinary records for public health and translational research. Recent machine learning efforts have been limited to predicting 42 top-level diagnosis categories from veterinary notes. Here we develop a large-scale algorithm to automatically predict all 4577 standard veterinary diagnosis codes from free text. We train our algorithm on a curated dataset of over 100,000 expert-labeled veterinary notes and over one million unlabeled notes. Our algorithm is based on an adapted Transformer architecture, and we demonstrate that large-scale language modeling on the unlabeled notes, both via pretraining and as an auxiliary objective during supervised learning, greatly improves performance. We systematically evaluate the performance of the model and several baselines in challenging settings where algorithms trained on one hospital are evaluated on a different hospital with substantial domain shift. In addition, we show that hierarchical training can address severe data imbalance for fine-grained diagnoses with few training cases, and we provide interpretations of what the deep network learns. Our algorithm addresses an important challenge in veterinary medicine, and our model and experiments add insight into the power of unsupervised learning for clinical natural language processing.


Dataset Details
We provide additional descriptive statistics of the dataset below. The training and evaluation CSU dataset, the external evaluation PP dataset, and the unsupervised-learning PSVG dataset differ due to the nature of the clinics. This additional information allows us to quantify the domain mismatch between CSU, PP, and PSVG.

Length Distribution
We plot a histogram of the proportion of records of each length in each dataset in Supplementary Figure 1. Noticeably, CSU and PP follow the SOAP format ("Subjective, Objective, Assessment, Plan"), but PSVG data, due to an API limitation, does not strictly follow this format. PSVG notes contain only the "History, Plan, Physical Exam" sections of the electronic medical record. As a result, PSVG notes are much shorter than CSU and PP notes.

Number of Labels Per Document Distribution
We plot a histogram of the proportion of records with each number of labels in each labeled dataset in Supplementary Figure 1. We have no labeled notes from PSVG, so it is not included in the plot.

Species Distribution
We plot pie charts of the proportion of species in each labeled dataset in Supplementary Figure 2. The CSU dataset contains a fair number of notes across different species (e.g., equine, bovine, etc.), while PP, being a suburban private clinic, is much more focused on house pets.

Model Details
We formulate the problem of veterinary diagnosis coding as a multi-label classification problem. Given a veterinary record X, which contains a detailed description of the diagnosis, we try to infer a subset of diagnoses y ⊆ Y, given a pre-defined set of diagnoses Y. The problem of inferring a subset of diagnosis codes can be viewed as a series of independent binary prediction problems 1 . The binary classifier learns to predict whether a diagnosis code y_i is present or not, for i = 1, ..., m, where m = |Y| = 4577.
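To make the formulation concrete, the following sketch (not the authors' code; the thresholding helper and example values are illustrative assumptions) shows how m independent Bernoulli probabilities are decoded into a predicted subset of diagnosis codes:

```python
import numpy as np

m = 4577  # |Y|, the number of standard diagnosis codes

def decode_predictions(probs, threshold=0.5):
    """Turn m per-code probabilities into a predicted subset of code indices."""
    return {i for i, p in enumerate(probs) if p >= threshold}

# Toy example: two codes predicted with high probability, the rest near zero.
probs = np.zeros(m)
probs[[3, 10]] = 0.9
predicted = decode_predictions(probs)
```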
Our learning system has two components: a text encoder module and a diagnosis code prediction module. In our work, we evaluated three text encoder modules: the convolutional neural network (CNN); the long short-term memory network (LSTM), which has demonstrated its effectiveness in learning implicit language patterns from text 2 ; and the Transformer network, recently proposed by Vaswani et al. 3 . Our diagnosis code prediction module consists of binary classifiers that are parameterized independently for each diagnosis.

Text Encoder
CNN The convolutional neural network (CNN) has been demonstrated to be effective for many NLP tasks 4 . Given a sequence of word embeddings x_1, ..., x_T, we apply a convolution operation with a window size of h (words) and a max-over-time pooling operation to get the summary vector c. The computation is described in Eq 1, where ⊕ and ⊗ indicate the concatenation operator and the convolution operator, and tanh is the hyperbolic tangent function.
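A minimal PyTorch sketch of this encoder, assuming the dimensions reported later in the setup (768-dimensional embeddings, 384 kernels, window size 4); this is an illustration, not the authors' exact implementation:

```python
import torch
import torch.nn as nn

class CNNEncoder(nn.Module):
    """Sketch: 1-D convolution over word embeddings with tanh activation,
    followed by max-over-time pooling to produce the summary vector c."""
    def __init__(self, d_emb=768, n_kernels=384, window=4):
        super().__init__()
        self.conv = nn.Conv1d(d_emb, n_kernels, kernel_size=window)

    def forward(self, x):                             # x: (batch, T, d_emb)
        h = torch.tanh(self.conv(x.transpose(1, 2)))  # (batch, n_kernels, T-window+1)
        return h.max(dim=2).values                    # summary vector c

enc = CNNEncoder()
c = enc(torch.randn(2, 600, 768))
```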
LSTM The long short-term memory network (LSTM) is a recurrent neural network with a long short-term memory cell 5 . A common LSTM network is composed of a hidden state h_t, a cell state c_t, an input gate i_t, an output gate o_t, and a forget gate f_t. It maintains gating functions specifically designed to capture long-term dependencies between words. Given a sequence of word embeddings x_1, ..., x_T, the recurrent computation of the LSTM network at time step t is described in Eq 2, where σ is the sigmoid function, σ(x) = 1/(1 + e^{−x}), and ⊙ indicates the Hadamard product.
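A sketch of the LSTM encoder using PyTorch's built-in cell; using the final hidden state as the summary vector is one plausible pooling choice (the paper's exact pooling is not specified here, so this detail is an assumption):

```python
import torch
import torch.nn as nn

class LSTMEncoder(nn.Module):
    """Sketch: run an LSTM over the embedding sequence and take the final
    hidden state as the summary vector c (assumed pooling choice)."""
    def __init__(self, d_emb=768, d_hidden=768):
        super().__init__()
        self.lstm = nn.LSTM(d_emb, d_hidden, batch_first=True)

    def forward(self, x):            # x: (batch, T, d_emb)
        out, (h_n, c_n) = self.lstm(x)
        return h_n[-1]               # (batch, d_hidden)

enc = LSTMEncoder()
c = enc(torch.randn(2, 50, 768))
```

A bidirectional variant (compared in the experiments) would concatenate the final states of both directions.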
Transformer The Transformer network was proposed by Vaswani et al. as a machine translation architecture 3 . We use a multi-layer Transformer setup similar to the one in Radford et al. 6 . The Transformer network is a feed-forward network that starts at the word embedding level. Given the word embeddings of a sequence x_1, ..., x_T ∈ R^d, we add a positional embedding to the sequence so that the model knows the location of each word. We define this positional embedding PE ∈ R^{T×d}, where T is an arbitrarily set maximum sequence length (usually much longer than the longest sequence in our training dataset); for notational convenience, we let it equal the sequence length T. Since PE is generated as a cyclical sine-cosine wave and never updated during training, we can easily generate PE for sequences longer than T. For i = 1, ..., d/2, we define the elements of the PE matrix in Eq 3 (symbols inside parentheses indicate the coordinates of the element).
PE(t, 2i) = sin(t/10000^{2i/d}), PE(t, 2i+1) = cos(t/10000^{2i/d})
We define the first input to the Transformer network as H_0 = X + PE, where X = {x_1, ..., x_T} and PE is defined above; note that H_0 ∈ R^{T×d}. Then, for a given layer l, l > 0, we define a feed-forward transformer block in Eq 4. We let W_v have dimensions (d/n) × d, so the resulting H^(i) has dimension T × (d/n). We additionally apply a mask M over the attention so that the model only looks at steps < t when it generates the token at step t. For the same layer l, we repeat the above computation n times. This is referred to as the multi-head attention computation, where n indicates the number of heads.
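The sinusoidal positional embedding of Eq 3 can be generated for any length, which is why it needs no training. A small numpy sketch (index conventions follow the standard sine/cosine formulation and are assumed to match the paper's):

```python
import numpy as np

def positional_encoding(T, d):
    """PE in R^{T x d}: even columns use sine, odd columns cosine,
    with wavelengths forming a geometric progression in 10000^{2i/d}."""
    pe = np.zeros((T, d))
    t = np.arange(T)[:, None]            # positions 0..T-1
    i = np.arange(d // 2)[None, :]       # frequency index
    angle = t / np.power(10000.0, 2 * i / d)
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe

PE = positional_encoding(600, 768)
```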
After the multi-head attention computation described above, we concatenate the n H^(i) matrices to obtain H̃_l ∈ R^{T×d}:
H̃_l = Concat(H^(1), H^(2), ..., H^(n))
We then apply a fully connected layer with a ReLU activation to this matrix to obtain the final hidden representation of the sequence for layer l, H_l; we describe the calculation in Eq 5. The matrix multiplications by W_o1 ∈ R^{D×d} and W_o2 ∈ R^{d×D} are referred to as a bottleneck computation, where D is much larger than d. The Transformer network repeats the above computation to construct a multi-layer Transformer network.
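The concatenation and bottleneck of Eq 5 can be sketched as follows (a simplified illustration using the dimensions from the experimental setup, d = 768, D = 2048, n = 8; layer normalization and residual connections are omitted):

```python
import torch
import torch.nn as nn

class ConcatBottleneck(nn.Module):
    """Sketch of Eq 5: concatenate the n attention heads back to dimension d,
    then project through a wide ReLU layer (D >> d) and back down to d."""
    def __init__(self, d=768, D=2048):
        super().__init__()
        self.w_o1 = nn.Linear(d, D)   # stands in for W_o1
        self.w_o2 = nn.Linear(D, d)   # stands in for W_o2

    def forward(self, heads):                    # list of n tensors (T, d/n)
        h = torch.cat(heads, dim=-1)             # H-tilde: (T, d)
        return self.w_o2(torch.relu(self.w_o1(h)))  # H_l: (T, d)

block = ConcatBottleneck()
H_l = block([torch.randn(10, 96) for _ in range(8)])  # n=8 heads of size d/n=96
```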

Diagnosis Code Prediction
We define a binary classifier for each of the 4577 diagnosis codes in our pre-defined set. The binary classifier takes in a summary vector c that represents the veterinary record and outputs a sufficient statistic for a Bernoulli probability distribution indicating the probability that the diagnosis should be predicted (Eq 6).
Flat Training We use the binary cross-entropy loss averaged across all labels as the training loss for flat training. Given the binary predictions from the model ŷ ∈ [0, 1]^m and the correct binary labels y ∈ {0, 1}^m, the binary cross-entropy loss is written in Eq 7. The decision boundary in our model is set to 0.5.
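A PyTorch sketch of the flat-training objective and the 0.5 decision boundary (toy tensors for illustration; the numerically stable logits-based loss is an implementation choice we assume):

```python
import torch
import torch.nn as nn

m = 4577
logits = torch.randn(2, m)        # raw per-code classifier outputs
targets = torch.zeros(2, m)       # multi-hot gold labels
targets[0, 3] = 1.0

# Eq 7: binary cross-entropy averaged over all labels (and the batch).
loss = nn.BCEWithLogitsLoss()(logits, targets)

# Inference: threshold the Bernoulli probabilities at 0.5.
preds = (torch.sigmoid(logits) >= 0.5).float()
```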
Hierarchical Training In this setting, we first define an adjacency matrix M, where M_ij = 1 if diagnostic code i is a child of diagnostic code j in the SNOMED-CT hierarchy, and M_ij = 0 otherwise. During training, we generate a mask b ∈ [0, 1]^m based on a recursive definition (Eq 8).
b_i = 1 if y_j = 1 and M_ij = 1 for some j; b_i = 0 otherwise
We then apply this mask to both the ground-truth label y and the prediction ŷ. The masking vector allows us to penalize the prediction of a diagnostic code only when its parent is present, thus greatly reducing the number of negative examples for rare diagnoses. We compute the hierarchical binary cross-entropy loss in Eq 9.
During inference time, we generate a masking vector b̃ by setting b̃_i = 1 when all the ancestors of diagnosis code i are predicted as true, and produce the final model prediction as ŷ · b̃.
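A small numpy sketch of the training-time mask: a code is kept only if its parent appears in the gold labels (the handling of root codes, which have no parent row in M, is omitted and would need a special case):

```python
import numpy as np

def training_mask(y, M):
    """Sketch of Eq 8: b_i = 1 iff some j has y_j = 1 and M[i, j] = 1,
    i.e. code i's parent is present in the ground-truth labels."""
    return ((M @ y) > 0).astype(float)

# Toy hierarchy: codes 1 and 2 are children of code 0.
M = np.zeros((3, 3))
M[1, 0] = M[2, 0] = 1
y = np.array([1.0, 1.0, 0.0])   # parent 0 and child 1 are labeled
b = training_mask(y, M)          # children of code 0 are unmasked
```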

Experimental Setup
We describe our experimental setup in the following section. We truncate all documents to no more than 600 tokens and pad them with start- and end-of-sentence tokens. This step helps reduce computational requirements.
Neural Network Architecture In order to have a fair comparison among the CNN, LSTM, and Transformer encoders, we set all latent dimensions to 768. For the CNN, we use 384 convolution kernels with a kernel size of 4. For the LSTM, we compare the performance of a unidirectional LSTM and a bidirectional LSTM. For the Transformer, we stack 6 transformer blocks, with 8 heads for the multi-head attention on each layer. We set the feed-forward dimension to 2048.
Pretraining All pretraining is conducted on the PSVG dataset. We investigate the effect of pretraining the word embeddings (+W) and pretraining the encoder with a language modeling objective (+P). For word embedding pretraining, we use the Word2Vec algorithm 7 on the PSVG dataset, with the word embedding dimension set to 768. For the language modeling objective, we initialize word embeddings with Xavier initialization 8 and directly optimize − log P(X).
Training We implement our model in PyTorch. We use the Noam optimizer 3 with 8000 warm-up steps. The dropout rate is set to 0.1 during training to reduce overfitting. All models are trained for 10 epochs. We use a batch size of 5 for each model, which is the maximum allowed to train VetTag on a single GPU.
MetaMap Baseline We use the popular MetaMap, a program developed by the National Library of Medicine (NLM) 9 , as a baseline. MetaMap processes a document and outputs a list of matched medically relevant keywords with their frequencies in the given document. We use MetaMap as a feature extractor, mapping each document into a frequency-encoded bag-of-words vector. The final feature vector size is 57,235. We perform the multi-label classification task with an SVM and an MLP on these feature vectors.
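The feature extraction step can be sketched as follows (MetaMap itself is an external program; the keyword list, helper name, and tiny vocabulary here are illustrative stand-ins for its output and the 57,235-word feature space):

```python
from collections import Counter

def bow_features(keywords, vocab):
    """Map MetaMap-style keyword matches for one document into a
    frequency-encoded bag-of-words vector over a fixed vocabulary."""
    counts = Counter(keywords)
    return [counts.get(w, 0) for w in vocab]

# Toy example: matched keywords for one note, over a 3-word vocabulary.
vec = bow_features(["fever", "cough", "fever"], ["cough", "fever", "nausea"])
```

Each document's vector would then be fed to the SVM or MLP classifier.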

Comparing VetTag and DeepTag
DeepTag is designed to make predictions on the 42 top-level diagnosis categories 10 . We restrict VetTag's predictions to these top-level categories, except for clinical finding (the spurious category), in order to compare its performance head-to-head with DeepTag. Note that VetTag is optimized to predict not just these categories but all 4577 categories; hence the comparison is more favorable for DeepTag. We report the result comparison in Supplementary Table 1. On the PP test data, VetTag substantially outperforms DeepTag in both F1 and exact match (EM). On the CSU test data, VetTag achieves a better EM score and a comparable F1 score relative to DeepTag. Supplementary Figure 3 compares VetTag and DeepTag on the 20 most frequent categories, demonstrating the superior performance of VetTag.

Interpretation Details
We compute the standard saliency map for each input text; this is defined as the input vector multiplied by the gradient of the predicted probability with respect to the input. The saliency of each word quantifies the influence of that word on VetTag's predictions. For each of the 41 top-level diagnosis categories, we select the top 50 words with the highest saliency for that diagnosis, defined as the words with a saliency score ≥ 0.2 in the largest number of clinical notes labeled with that diagnosis. We then intersect these 50 most salient words with the MetaMap expert-curated dictionary in order to select the most medically relevant words. These words are shown in Supplementary Table 2.
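The input-times-gradient saliency computation can be sketched in PyTorch as follows (the toy model and the sum over the embedding dimension for a per-word score are illustrative assumptions, not the authors' exact reduction):

```python
import torch

def saliency(model, x):
    """Per-word saliency: the embedding multiplied elementwise by the
    gradient of the predicted probability w.r.t. that embedding,
    summed over the embedding dimension."""
    x = x.clone().requires_grad_(True)
    prob = model(x)          # scalar predicted probability for one diagnosis
    prob.backward()
    return (x * x.grad).sum(dim=-1)   # one score per word position

# Toy stand-in model: linear projection of the embeddings through a sigmoid.
w = torch.randn(768)
model = lambda x: torch.sigmoid((x @ w).mean())
scores = saliency(model, torch.randn(20, 768))   # 20 "words"
```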