# Source code for HieDiff._HieDiff_minibatch

# -*- coding: utf-8 -*-

from typing import Optional, Tuple
from anndata import AnnData

import torch
import numpy as np
import pandas as pd
from scipy.sparse import issparse, csr_matrix

from tqdm import trange

from ._moduleW import HieDiff
from ._model_utils import one_hot


def _run_HieDiff_minibatch(
    X: np.ndarray,
    n_epochs: int,
    n_hidden: int,
    n_latent: int,
    n_batch: int,
    batch_index: Optional[np.ndarray],
    n_covar: int,
    covar: Optional[np.ndarray],
    device: Optional[str],
    batch_size: int = 3000,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''
    Train a ``HieDiff`` model with mini-batch Adam and return its outputs.

    Parameters
    ----------
    X
        Dense observation-by-feature matrix used as model input.
    n_epochs
        Number of passes over the (re-shuffled) observations.
    n_hidden
        Forwarded to ``HieDiff`` as ``n_hidden``.
    n_latent
        Forwarded to ``HieDiff`` as ``n_latent``.
    n_batch
        Number of batch categories; width of the one-hot batch encoding.
    batch_index
        Integer batch code per observation, or ``None`` for no batch term.
    n_covar
        Number of covariate columns in ``covar``.
    covar
        Covariates, reshaped to ``(-1, n_covar)``, or ``None``.
    device
        Torch device string. ``None`` or ``'cuda'`` selects CUDA when it is
        available and falls back to CPU otherwise.
    batch_size
        Number of observations per optimisation step. Default to 3000.

    Returns
    -------
    qz
        ``inference_outputs['qz'].loc`` evaluated on all of ``X``.
    x4
        ``generative_outputs['x4']`` evaluated on all of ``X``.
    W
        The value returned by ``model.layerW.getW()``. NOTE(review): annotated
        as an ndarray because the caller wraps it in ``csr_matrix``; confirm
        against ``HieDiff.layerW.getW``.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    '''
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    if device is None or device == 'cuda':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    data_X = torch.Tensor(X).to(device)
    if batch_index is not None:
        batch_index = one_hot(torch.Tensor(batch_index).to(device), n_batch)

    if covar is not None:
        covar = torch.Tensor(covar).reshape((-1, n_covar)).to(device)
        if batch_index is not None:
            # Join column-wise: the model below is built with
            # n_covar + n_batch covariate inputs per observation, so the
            # one-hot batch columns go next to the covariate columns.
            # (Assumes one_hot yields one row per observation.)
            covar = torch.cat((covar, batch_index), dim=1)
    else:
        covar = batch_index

    n_obs = data_X.shape[0]
    num_batches = n_obs // batch_size
    batch_size_extra = n_obs % batch_size

    model = HieDiff(
        n_input=data_X.shape[1],
        n_covar=n_covar + n_batch,
        n_hidden=n_hidden,
        n_latent=n_latent,
    ).to(device)

    model.train(mode=True)

    # Only optimise parameters that are flagged as trainable.
    params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(params, lr=1e-3, eps=0.01, weight_decay=1e-6)

    pbar = trange(n_epochs)

    for epoch in pbar:

        # Fresh shuffle of the observation indices every epoch.
        n_perm = np.random.permutation(n_obs)

        loss_epoch = 0.0

        for batch_i in range(num_batches + 1):

            if batch_i < num_batches:
                start = batch_i * batch_size
                batch_idx = n_perm[start:start + batch_size]
            elif batch_size_extra:
                # Remainder step: uses the last `batch_size` shuffled indices,
                # so it overlaps the previous batch (or covers everything when
                # n_obs < batch_size) instead of being a short batch.
                # NOTE(review): presumably intentional to keep the batch size
                # constant - confirm.
                batch_idx = n_perm[-batch_size:]
            else:
                break

            optimizer.zero_grad()

            # Slice the mini-batch once and reuse it for every model call.
            x_tmp = data_X[batch_idx, :]
            covar_tmp = covar[batch_idx, :] if covar is not None else None

            inference_outputs = model.inference(x_tmp)
            generative_outputs = model.generative(inference_outputs['z'], covar_tmp)
            W_outputs = model.forward_W(x_tmp, covar_tmp)

            # The last argument is the training progress in [0, 1).
            loss = model.loss(
                x_tmp,
                inference_outputs,
                generative_outputs,
                W_outputs,
                epoch / n_epochs,
            )

            loss_epoch += loss.item()

            loss.backward()
            optimizer.step()

        pbar.set_postfix_str(f'loss: {loss_epoch:.3e}')

    model.eval()

    # Final full-data pass without gradient tracking to collect the outputs.
    with torch.no_grad():
        inference_outputs = model.inference(data_X)
        generative_outputs = model.generative(inference_outputs['z'], covar)
        qz = inference_outputs['qz'].loc.detach().cpu().numpy()
        x4 = generative_outputs['x4'].detach().cpu().numpy()
        W = model.layerW.getW()

    return qz, x4, W


def run_HieDiff_minibatch(
    adata: AnnData,
    n_epochs: int = 1000,
    n_hidden: int = 128,
    n_latent: int = 10,
    batch_key: Optional[str] = None,
    covar_key: Optional[str] = None,
    device: Optional[str] = None,
    copy: bool = False,
) -> Optional[AnnData]:
    '''
    The mini-batch implementation for hierarchical flow diffusion.

    Parameters
    ----------
    adata
        Annotated data matrix.
    n_epochs
        Number of epochs for training neural network. Default to 1000.
    n_hidden
        Number of neurons in the hidden layer. Default to 128.
    n_latent
        Number of neurons in the latent layer. Default to 10.
    batch_key
        The key to retrieve batch information from `adata.obs[batch_key]`.
        If not specified, batch correction is not considered.
    covar_key
        The key to retrieve covariates. It is looked up first in
        `adata.obs` (a single covariate column) and then in `adata.obsm`
        (a covariate matrix). If not specified, covariates are not
        considered.
    device
        The desired device for `PyTorch` computation. By default uses cuda
        if cuda is available, cpu otherwise.
    copy
        Return a copy instead of writing to ``adata``.

    Returns
    -------
    Depending on ``copy``, returns or updates ``adata`` with the following fields.

    .obsm['qz'] : :class:`~numpy.ndarray`
        The latent representation of gene expression.
    .varp['W'] : :class:`~scipy.sparse.csr_matrix`
        The gene-by-gene relation matrix.
    .layers['x4'] : :class:`~scipy.sparse.csr_matrix`
        The denoised gene expression matrix.
    .uns['HieDiff'] : dict
        ``{'params': {'method': 'umap'}}``.
    .uns['W'] : dict
        Points ``connectivities_key`` and ``distances_key`` at ``'W'``.

    Raises
    ------
    KeyError
        If ``covar_key`` is given but is found in neither ``adata.obs`` nor
        ``adata.obsm``.
    '''
    adata = adata.copy() if copy else adata

    if batch_key is not None:
        # Encode batch labels as integer category codes.
        batch_info = pd.Categorical(adata.obs[batch_key])
        n_batch = batch_info.categories.shape[0]
        batch_index = batch_info.codes.copy()
    else:
        n_batch = 0
        batch_index = None

    if covar_key is None:
        n_covar = 0
        covar = None
    elif covar_key in adata.obs.keys():
        covar = adata.obs[covar_key].to_numpy()
        n_covar = 1
    elif covar_key in adata.obsm.keys():
        covar = np.array(adata.obsm[covar_key])
        n_covar = covar.shape[1]
    else:
        # Previously this fell through with `covar` / `n_covar` unbound and
        # failed later with an opaque UnboundLocalError.
        raise KeyError(
            f'covar_key {covar_key!r} not found in adata.obs or adata.obsm'
        )

    qz, x4, W = _run_HieDiff_minibatch(
        X=adata.X.toarray() if issparse(adata.X) else adata.X,
        n_epochs=n_epochs,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_batch=n_batch,
        batch_index=batch_index,
        n_covar=n_covar,
        covar=covar,
        device=device,
    )

    key_added = 'HieDiff'
    qz_key = 'qz'
    x4_key = 'x4'

    adata.uns[key_added] = {'params': {'method': 'umap'}}

    adata.obsm[qz_key] = qz
    adata.layers[x4_key] = csr_matrix(x4)

    # Metadata entry that names the `.varp` slot holding W for both the
    # connectivities and the distances key.
    adata.uns['W'] = {
        'connectivities_key': 'W',
        'distances_key': 'W',
    }
    adata.varp['W'] = csr_matrix(W)

    return adata if copy else None