1. Introduction
In "Cross-tissue, single-cell stromal atlas identifies shared pathological fibroblast phenotypes in four chronic inflammatory diseases", the authors use a technique called "weighted PCA", together with Harmony, to remove batch effects across a wide variety of diseases. The most important observation they made is the stark difference in the number of cells between tissues, which would otherwise let the best-represented tissues dominate an ordinary PCA. I rewrote the R implementation from https://github.com/immunogenomics/singlecellmethods in Python.
The original code, however, comes with little explanation.
To dive into this, I would like to explain the concept of weighted PCA. I start by formalizing the weighted expectation and the weighted variance, and eventually build up to the weighted covariance matrix. There are various implementations of weighted PCA out there. This one is perhaps the simplest and most intuitive, and it generalizes naturally from the original definition of PCA. First, we define a weighted inner product on Euclidean space:
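$$\langle \mathbf{x}, \mathbf{y} \rangle_W = \mathbf{x}^\top W \mathbf{y} = \sum_{i=1}^{n} w_i x_i y_i, \qquad \mathbf{x}, \mathbf{y} \in \mathbb{R}^n,$$
where the diagonal entries of $W = \operatorname{diag}(w_1, \dots, w_n)$ are the non-negative weights of the $n$ observations. We normalize these weights so that $\sum_{i} w_i = 1$.
Then, we can define the weighted mean of a vector $\mathbf{x}$ as its weighted inner product with the all-ones vector $\mathbf{1}$:
$$\mu_W(\mathbf{x}) = \langle \mathbf{x}, \mathbf{1} \rangle_W = \sum_{i=1}^{n} w_i x_i,$$
After centering with this mean, we obtain a weighted covariance, variance, and correlation:
$$\operatorname{cov}_W(\mathbf{x}, \mathbf{y}) = \langle \mathbf{x} - \mu_W(\mathbf{x})\mathbf{1},\ \mathbf{y} - \mu_W(\mathbf{y})\mathbf{1} \rangle_W, \qquad \operatorname{var}_W(\mathbf{x}) = \operatorname{cov}_W(\mathbf{x}, \mathbf{x}), \qquad \operatorname{cor}_W(\mathbf{x}, \mathbf{y}) = \frac{\operatorname{cov}_W(\mathbf{x}, \mathbf{y})}{\sqrt{\operatorname{var}_W(\mathbf{x})\,\operatorname{var}_W(\mathbf{y})}}.$$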
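These definitions translate almost line by line into NumPy. Below is a minimal sketch, assuming a 1-D weight array w that is non-negative and sums to 1. The function names are illustrative and are not taken from the original code.

```python
import numpy as np


def inner_w(x, y, w):
    """Weighted inner product <x, y>_W = x^T diag(w) y."""
    return float(np.sum(w * x * y))


def mean_w(x, w):
    """Weighted mean <x, 1>_W (assumes w sums to 1)."""
    return inner_w(x, np.ones_like(x), w)


def cov_w(x, y, w):
    """Weighted covariance: inner product of the weighted-centered vectors."""
    return inner_w(x - mean_w(x, w), y - mean_w(y, w), w)


def var_w(x, w):
    """Weighted variance: weighted covariance of x with itself."""
    return cov_w(x, x, w)


def corr_w(x, y, w):
    """Weighted Pearson correlation."""
    return cov_w(x, y, w) / np.sqrt(var_w(x, w) * var_w(y, w))


# Toy data: the last observation carries half of the total weight.
x = np.array([1.0, 2.0, 3.0, 10.0])
y = 2 * x
w = np.array([1 / 6, 1 / 6, 1 / 6, 1 / 2])

print(mean_w(x, w))     # 6.0 up to floating-point rounding
print(corr_w(x, y, w))  # 1.0 up to floating-point rounding, since y = 2x
```

With uniform weights w_i = 1/n, these reduce to the ordinary mean and to the population (ddof=0) variance and covariance.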
In short, each of our basic measures (mean, variance, covariance, correlation) is replaced with its weighted version. It makes sense to build these from the ground up on a different inner product, because each of them is ultimately a dot product. The weight of each observation can also be interpreted as its frequency. In an imbalanced situation, it is favorable to use these weights to account for the different frequencies of each class of observations.
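To make the frequency interpretation concrete, consider integer weights. The weighted mean and variance then coincide with the ordinary mean and population variance of a dataset in which each observation is repeated w_i times. Here is a quick numerical check; the values and counts are made up for illustration.

```python
import numpy as np

x = np.array([0.5, 1.5, 4.0])     # three distinct observations
counts = np.array([3, 1, 2])      # how often each one "occurs"

# Weighted statistics with normalized weights w_i = counts_i / sum(counts).
w = counts / counts.sum()
weighted_mean = np.sum(w * x)
weighted_var = np.sum(w * (x - weighted_mean) ** 2)

# Ordinary statistics of the expanded data, where x_i is repeated counts_i times.
expanded = np.repeat(x, counts)
print(weighted_mean, expanded.mean())  # equal up to floating-point rounding
print(weighted_var, expanded.var())    # np.var defaults to ddof=0, matching the weighted definition
```

Turning this around, suppose we choose weights inversely proportional to the size of each class. Every class then behaves as if it contained the same number of observations.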
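Now, consider a large matrix $X \in \mathbb{R}^{n \times p}$ with the $n$ observations (cells) in rows and the $p$ features (genes) in columns. We first center every column by its weighted mean: $\tilde{X} = X - \mathbf{1}\boldsymbol{\mu}_W^\top$, with $\boldsymbol{\mu}_W = X^\top W \mathbf{1}$.
The covariance matrix is, in fact, no longer $\frac{1}{n-1}\tilde{X}^\top \tilde{X}$ but
$$\Sigma_W = \tilde{X}^\top W \tilde{X},$$
whose $(j, k)$ entry is the weighted covariance of columns $j$ and $k$. With equal weights $w_i = 1/n$, it reduces to $\frac{1}{n}\tilde{X}^\top \tilde{X}$.
Here $W = \operatorname{diag}(w_1, \dots, w_n)$ is the diagonal matrix of observation weights (summing to 1), and hence
$$\Sigma_W = \left(W^{1/2}\tilde{X}\right)^\top \left(W^{1/2}\tilde{X}\right).$$
Weighted PCA is therefore ordinary PCA on the rescaled matrix $W^{1/2}\tilde{X}$. The principal axes are the eigenvectors of $\Sigma_W$, or equivalently the right singular vectors of $W^{1/2}\tilde{X}$. Projecting $\tilde{X}$ onto them gives the weighted principal components of the cells. If each column is additionally scaled to unit weighted variance, $\Sigma_W$ becomes the weighted correlation matrix.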
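The matrix form can be checked numerically. The sketch below uses random toy data, and its variable names are illustrative. It builds the weighted covariance matrix as X~^T W X~ and compares it with NumPy's own weighted covariance: np.cov with aweights and bias=True, which divides by the sum of the weights. It then reads off the principal axes from an SVD of W^{1/2} X~.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 50, 4
X = rng.normal(size=(n, p))         # n observations (rows) x p features (columns)
w = rng.uniform(0.1, 1.0, size=n)
w = w / w.sum()                     # normalize the weights to sum to 1

mu_w = X.T @ w                      # column-wise weighted means, X^T W 1
Xc = X - mu_w                       # weighted-centered matrix X~ (broadcast over rows)
Sigma_w = Xc.T @ np.diag(w) @ Xc    # weighted covariance matrix X~^T W X~

# NumPy's weighted covariance; with bias=True it normalizes by sum(w), which is 1 here.
Sigma_np = np.cov(X, rowvar=False, aweights=w, bias=True)

# Weighted PCA as ordinary PCA on W^{1/2} X~: the right singular vectors are the axes.
Y = np.sqrt(w)[:, None] * Xc
_, s, Vt = np.linalg.svd(Y, full_matrices=False)
eigvals = s ** 2                    # eigenvalues of Sigma_w, in descending order
scores = Xc @ Vt.T                  # projections of the centered observations
```

Forming np.diag(w) explicitly costs O(n^2) memory, which is fine for a toy example. With many cells, it is cheaper to scale the rows directly, as the line computing Y does.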
2. Usage
Ideally, the weights array sums to 1, and each cell's weight is inversely proportional to the size of its population (for example, its batch).
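To see what such weights do, here is a small sketch on a toy pandas Series that stands in for the per-cell batch labels. The batch names and expression values are made up. With a batch of 4 cells and a batch of 1 cell, each batch receives a total weight of 1/2. The weighted mean of a gene then becomes the average of the two per-batch means, instead of being dominated by the larger batch.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the batch labels: 4 cells from batch "A", 1 cell from batch "B".
labels = pd.Series(["A", "A", "A", "A", "B"])
expr = np.array([1.0, 1.0, 1.0, 1.0, 5.0])        # one gene's expression in each cell

freq = labels.value_counts()                      # cells per batch: A -> 4, B -> 1
w = (1 / freq).loc[labels].to_numpy() / freq.size # 1 / (batch size), divided by the number of batches

print(w)                                          # A cells get 1/8 each, the B cell gets 1/2
print(pd.Series(w).groupby(labels).sum())         # each batch contributes a total weight of 1/2
print(expr.mean(), np.sum(w * expr))              # 1.8 (dominated by A) vs 3.0 (mean of batch means 1 and 5)
```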
For example, if the batch of each cell is stored in the column anndata.obs[Batch_key], we can compute the per-cell weights with:
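labels = anndata.obs[Batch_key].astype(str)  # plain string labels (obs columns are often categorical)
freq = labels.value_counts()                  # number of cells in each batch
w = 1 / freq                                  # weight per batch, inversely proportional to its size
w = w.loc[labels].to_numpy() / freq.size      # one weight per cell; divide by the number of batches so the weights sum to 1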
The weight computation can be wrapped into a small helper function:
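import numpy as np

def generate_weights(anndata, Batch_key):
    """Per-cell weights, inversely proportional to batch size and summing to 1."""
    assert Batch_key in anndata.obs.columns
    labels = anndata.obs[Batch_key].astype(str)  # avoids counting unused categorical levels
    freq = labels.value_counts()                  # number of cells in each batch
    w = 1 / freq
    w = w.loc[labels] / freq.size                 # one weight per cell, normalized to sum to 1
    return np.asarray(w)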
You can then compute a new weighted-PCA representation by running:
w = generate_weights(anndata, batch)
anndata.obsm["X_wpca"] = weighted_pca(anndata, w, n_comp, corr=True)
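The weighted_pca function called above is not shown here. For experimenting, here is a minimal sketch of a function with the same argument order, following the derivation in the introduction. It centers the data by the weighted mean, optionally scales it by the weighted standard deviation when corr=True, and then takes an SVD of W^{1/2} Z. The name weighted_pca_sketch is hypothetical. This is an assumption-laden stand-in, not the author's implementation or the singlecellmethods code. It assumes a dense adata.X with cells in rows, and it omits any additional preprocessing the original may perform.

```python
import numpy as np
from types import SimpleNamespace


def weighted_pca_sketch(adata, w, n_comp, corr=True):
    """Hypothetical weighted PCA returning an (n_cells, n_comp) embedding.

    Assumes adata.X is a dense (n_cells, n_genes) array and w holds one
    non-negative weight per cell.
    """
    X = np.asarray(adata.X, dtype=float)
    w = np.asarray(w, dtype=float)
    w = w / w.sum()                           # make sure the weights sum to 1

    mu = w @ X                                # weighted mean of each gene
    Z = X - mu                                # weighted centering
    if corr:
        sd = np.sqrt(w @ Z**2)                # weighted standard deviation of each gene
        Z = Z / np.where(sd > 0, sd, 1.0)     # leave constant genes unscaled

    # Principal axes: right singular vectors of W^{1/2} Z (eigenvectors of Z^T W Z).
    _, _, Vt = np.linalg.svd(np.sqrt(w)[:, None] * Z, full_matrices=False)
    return Z @ Vt[:n_comp].T                  # project every cell onto the top n_comp axes


# Toy usage with a stand-in object that only has an .X attribute.
rng = np.random.default_rng(1)
toy = SimpleNamespace(X=rng.normal(size=(30, 6)))
w_toy = np.r_[np.full(25, 1 / 50), np.full(5, 1 / 10)]  # a 25-cell and a 5-cell batch
emb = weighted_pca_sketch(toy, w_toy, n_comp=2)
print(emb.shape)                              # (30, 2)
```

With a dense expression matrix, this sketch could take the place of the call above: anndata.obsm["X_wpca"] = weighted_pca_sketch(anndata, w, n_comp, corr=True). A sparse anndata.X would first need to be densified, for example with .toarray(), which can be memory-hungry for large atlases.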