Efficient Transfer Learning with Sequential and Multi-Modal Approaches for Electronic Health Records
Yogesh Kumar
25th October 2024
Outline
- Forecasting Healthcare Utilization
  - Publication I
  - Publication II
- Comparing two neural networks
- Multi-modal Contrastive Learning
Part 1
Forecasting Healthcare Utilization
Structured EHR
Tabular information stored in EHR databases is actually temporal when grouped by individual
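A minimal sketch of this idea, using hypothetical column names (patient_id, visit_date, icd_code): grouping a flat EHR table by patient and sorting by time turns it into per-individual temporal sequences.

```python
import pandas as pd

# Hypothetical EHR table: one row per recorded event
ehr = pd.DataFrame({
    "patient_id": [1, 1, 2, 1, 2],
    "visit_date": pd.to_datetime(["2020-01-05", "2020-03-12", "2020-02-01",
                                  "2021-06-30", "2020-11-15"]),
    "icd_code": ["E11", "I10", "J45", "E11", "J45"],
})

# Grouping by individual and sorting by time turns the flat table
# into a temporal sequence of diagnosis codes per patient
sequences = (ehr.sort_values("visit_date")
                .groupby("patient_id")["icd_code"]
                .apply(list))
print(sequences)
```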
Predicting Healthcare Utilization
- Task: Predict the number of physical visits to a health center in the next year
- Cohort: Senior citizens - 65 years or older
Predicting Healthcare Utilization
Task: predict the number of physical visits to a health center in the next year
Why is this important?
- This could help with risk adjustment while allocating resources
- Patients with chronic conditions might need more visits
- Hospitals specializing in serious illnesses tend to have more visits, which doesn't mean they are ineffective
Research Q1: Can integration of temporal dynamics of EHR data improve accuracy and efficiency?
Research Q1: Can integration of temporal dynamics of EHR data improve accuracy and efficiency?
How useful is temporal information?
Divergent subgroup
Different subgroups of the general population have different target distributions
- We can’t use a single model for all subgroups
- We can’t build individual models for each subgroup (e.g. N = 100 patients)
Self-supervision in EHR
BEHRT: Transformer for Electronic Health Records, Li et al. 2020
Self-supervision in EHR
EHR data is not similar to natural language
Research Q2: Do models custom-built for EHR outperform generic NLP models?
Predicting utilization on divergent subgroups
Research Q2: Do models custom-built for EHR outperform generic NLP models?
- Task
  - Predict the number of physical visits
  - Predict visit count for six major disease categories
- Evaluation on specific patient subgroups
  - Type 2 diabetes (N ~ 40,000)
  - Bipolar disorder (N ~ 1,000)
  - Multiple sclerosis (N ~ 100)
A bespoke model for EHR
A bespoke model for EHR
Replace self-attention with mixers
Pay Attention to MLPs, Liu et al. 2021
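A minimal sketch of the idea behind this slide: a gMLP-style block (Pay Attention to MLPs, Liu et al. 2021) where a spatial gating unit mixes tokens in place of self-attention. Dimensions and layer names are illustrative, not the exact SANSformer implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialGatingUnit(nn.Module):
    """Gating over the sequence dimension, as in gMLP (Liu et al. 2021)."""
    def __init__(self, dim, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(dim // 2)
        # Token-mixing projection along the sequence axis
        self.spatial_proj = nn.Linear(seq_len, seq_len)

    def forward(self, x):                                 # x: (batch, seq_len, dim)
        u, v = x.chunk(2, dim=-1)                         # split channels
        v = self.norm(v)
        v = self.spatial_proj(v.transpose(1, 2)).transpose(1, 2)
        return u * v                                      # element-wise gating

class MixerBlock(nn.Module):
    """One gMLP-style block: an MLP mixer replaces self-attention."""
    def __init__(self, dim, seq_len, ffn_mult=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj_in = nn.Linear(dim, dim * ffn_mult)
        self.sgu = SpatialGatingUnit(dim * ffn_mult, seq_len)
        self.proj_out = nn.Linear(dim * ffn_mult // 2, dim)

    def forward(self, x):
        residual = x
        x = F.gelu(self.proj_in(self.norm(x)))
        x = self.sgu(x)
        return residual + self.proj_out(x)

block = MixerBlock(dim=64, seq_len=128)
out = block(torch.randn(8, 128, 64))                      # (batch, seq_len, dim)
```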
Evaluating sample efficiency
SANSformer pretrained on the general population shows better sample efficiency
Part 2
Comparing two neural networks
Representational similarity
Pairwise similarity yields a symmetric representational similarity matrix (RSM)
Representational similarity
We can compare the similarity between RSMs of two networks
Representational similarity
Similarity measure between two networks (A & B) in two steps
- Compute pairwise similarity of each sample representation to get similarity matrices:
\( S_A = \text{sim}_k(X_A, X_A) \)
\( S_B = \text{sim}_k(X_B, X_B) \)
- Compute similarity between the similarity matrices:
\( S_{A \leftrightarrow B} = \text{sim}_s(S_A, S_B) \)
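A minimal sketch of the two-step procedure above, assuming a linear kernel for \(\text{sim}_k\) and linear CKA for \(\text{sim}_s\):

```python
import numpy as np

def rsm(X):
    """Step 1: pairwise similarity of sample representations (linear kernel)."""
    Xc = X - X.mean(axis=0, keepdims=True)     # centre features
    return Xc @ Xc.T                           # (n_samples, n_samples)

def cka(S_a, S_b):
    """Step 2: similarity between the two similarity matrices (linear CKA)."""
    return (S_a * S_b).sum() / (np.linalg.norm(S_a) * np.linalg.norm(S_b))

# X_a, X_b: representations of the same samples from networks A and B
X_a, X_b = np.random.randn(100, 64), np.random.randn(100, 32)
S_a, S_b = rsm(X_a), rsm(X_b)
print(cka(S_a, S_b))
```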
Measuring domain similarity
We would like to quantify the similarity between two domains
A small discriminator tries to predict the domain for a given sample
The discriminator will find it harder to classify samples from similar domains
Measuring domain similarity
The BCE loss quantifies the domain similarity
English is more similar to Spanish than, say, Russian.
Hence the BCE loss is higher for English-Spanish than for English-Russian
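A minimal sketch of the domain discriminator; the embedding sizes and architecture are hypothetical, chosen only to illustrate how the BCE loss serves as the similarity signal.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical setup: emb_a, emb_b are encoder embeddings of samples from
# two domains (e.g. English and Spanish sentences); sizes are illustrative
emb_a, emb_b = torch.randn(256, 128), torch.randn(256, 128)

# A small discriminator that tries to predict the domain of each sample
discriminator = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1))

x = torch.cat([emb_a, emb_b])
y = torch.cat([torch.zeros(256), torch.ones(256)])        # domain labels

logits = discriminator(x).squeeze(-1)
bce = F.binary_cross_entropy_with_logits(logits, y)
# After training, a *higher* BCE loss means the discriminator struggles to
# separate the domains, i.e. the domains are more similar
print(bce.item())
```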
CKA and domain similarity
CKA, when viewed block-wise, doesn't correlate well with domain similarity
CKA between PT (English) - FT (Spanish) is lower than between PT (English) - FT (Russian)
This could be due to a bias from input similarity
Research Q3: Can we improve consistency of representation similarity by adjusting for the bias caused by the input structure?
Adjusting for input similarity
Research Q3: Can we improve consistency of representation similarity by adjusting for the bias caused by the input structure?
Given a dataset $X$ and a network $A$
\( S_X = \text{sim}_k(X, X) \)
\( S_A = \text{sim}_k(X_A, X_A)\)
We fit a linear model on input similarity and obtain the residual $\epsilon_A$
\( S_A = \alpha S_X + \epsilon_A \)
The residual $\epsilon_A$ is the adjusted similarity matrix
\( \epsilon_A = S_A - \alpha S_X \)
Similarity is then computed on the adjusted similarity matrices for networks A and B
\( S_{A \leftrightarrow B} = \text{sim}_s(\epsilon_A, \epsilon_B) \)
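A minimal sketch of the adjustment above, assuming a no-intercept least-squares fit of \(S_A\) on \(S_X\) and linear CKA as \(\text{sim}_s\):

```python
import numpy as np

def rsm(X):
    Xc = X - X.mean(axis=0, keepdims=True)
    return Xc @ Xc.T

def cka(S1, S2):
    return (S1 * S2).sum() / (np.linalg.norm(S1) * np.linalg.norm(S2))

def adjusted_rsm(S_net, S_x):
    """Regress the network RSM on the input RSM and keep the residual epsilon."""
    s_net, s_x = S_net.ravel(), S_x.ravel()
    alpha = (s_x @ s_net) / (s_x @ s_x)        # least-squares fit S_net ~ alpha * S_x
    return S_net - alpha * S_x

# X: raw inputs; X_a, X_b: representations of the same samples in networks A and B
X, X_a, X_b = (np.random.randn(100, d) for d in (20, 64, 32))
S_x, S_a, S_b = rsm(X), rsm(X_a), rsm(X_b)
print(cka(adjusted_rsm(S_a, S_x), adjusted_rsm(S_b, S_x)))   # dCKA
```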
Adjusted CKA
Representational similarity of networks finetuned on similar domains should be higher
and vice versa
dCKA correlates better with domain similarity
Part 3
Multi-modal Contrastive Learning
Contrastive Language Image Pretraining (CLIP)
maximize cosine similarity of positive pairs
minimize it for negative pairs
Medical CLIP
CLIP relies on weak supervision
Local regions of image and text are not closely aligned
Medical CLIP
We could utilize heatmaps obtained from radiologist eye-tracking data to improve alignment
Research Q4: Can high-quality expert annotations improve medical CLIP performance?
Expert annotated CLIP (eCLIP)
A heatmap processor applies multi-head attention (MHA) between the heatmap and the original image to create the expert image $I_i^E$
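A rough sketch of such a heatmap processor: cross-attention where heatmap patches act as queries over image patches. The patch-embedding details and sizes are assumptions for illustration, not the exact eCLIP architecture.

```python
import torch
import torch.nn as nn

class HeatmapProcessor(nn.Module):
    """Cross-attention between eye-tracking heatmap patches (queries)
    and image patches (keys/values) to produce expert image tokens."""
    def __init__(self, dim=256, patch=16, heads=8):
        super().__init__()
        self.img_patches = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.hm_patches = nn.Conv2d(1, dim, kernel_size=patch, stride=patch)
        self.mha = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, image, heatmap):        # (B, 3, H, W), (B, 1, H, W)
        img = self.img_patches(image).flatten(2).transpose(1, 2)   # (B, N, dim)
        hm = self.hm_patches(heatmap).flatten(2).transpose(1, 2)   # (B, N, dim)
        expert, _ = self.mha(query=hm, key=img, value=img)
        return expert                          # tokens of the expert image I^E_i

proc = HeatmapProcessor()
expert_tokens = proc(torch.randn(2, 3, 224, 224), torch.randn(2, 1, 224, 224))
```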
CLIP loss function
InfoNCE loss for contrastive learning
CLIP uses one pair per image-text example $\rightarrow$ ($v_i$, $t_i$)
CLIP loss function
InfoNCE loss for contrastive learning
We get additional pairs from the expert image
Similar to Data Augmentation, maybe?
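A minimal sketch of the symmetric InfoNCE loss, with the expert-image embedding contributing an extra pair per sample, treated like an augmented view. Embedding names and dimensions are hypothetical.

```python
import torch
import torch.nn.functional as F

def info_nce(img_emb, txt_emb, temperature=0.07):
    """Symmetric InfoNCE over a batch of (image, text) embedding pairs."""
    img_emb = F.normalize(img_emb, dim=-1)
    txt_emb = F.normalize(txt_emb, dim=-1)
    logits = img_emb @ txt_emb.t() / temperature      # cosine similarities
    labels = torch.arange(len(img_emb))               # positives on the diagonal
    return 0.5 * (F.cross_entropy(logits, labels) +
                  F.cross_entropy(logits.t(), labels))

# Hypothetical embeddings: v (image), t (text), v_expert (expert image I^E)
v, t, v_expert = (torch.randn(32, 512) for _ in range(3))
loss = info_nce(v, t)                  # standard CLIP pair (v_i, t_i)
loss = loss + info_nce(v_expert, t)    # extra pair from the expert image
```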
Modality gap in CLIP embeddings
UMAP projection reveals that CLIP embeddings form distinct clusters that pertain to their respective modalities (e.g., images and text)
Expert annotated CLIP (eCLIP)
We create an additional modality, expert image $I_i^E$, by utilizing the modality gap
Sample efficiency
eCLIP creates more embedding pairs per X-ray, improving sample efficiency
Conclusions
- Sample efficiency
  - Having the right inductive biases for the model
  - Pretraining on a larger, related dataset
- Comparing two neural networks
  - Adjusting for bias from input similarity improves functional consistency
- Multi-modal Contrastive Learning
  - Expert annotations as an additional modality improve contrastive learning
Thank you! ❤️️
Say hi:
(for a few more months...) yogesh.kumar@aalto.fi
ykumar@nyu.edu
ykumards.com