Source code for sift._sift

import logging
import os
from types import MappingProxyType
from typing import Any, Dict, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from anndata import AnnData
from scipy.sparse import issparse

import sift._backend_utils as bu
import sift._types as st
from sift._constants import BATCH_SIZE, MAX_SRC_SIZE, Backend, KernelType, Stage
from sift._utils import (
    _get_embedding,
    _get_mask_idx,
    _prepare_kernel_container,
    _save_res,
)
from sift._utils import plot_kernel as _plot_kernel
from sift.kernels import Kernel

logger = logging.getLogger(__name__)


__all__ = ["SiFT", "sifter"]


class SiFT:
    """SiFT object.

    Parameters
    ----------
    adata
        Annotated data object.
    kernel_key
        Key in :class:`anndata.AnnData` where information for distance is stored.
    metric
        The kernel type.
    mask_key
        Key in :attr:`anndata.AnnData.obs` for k-NN masking. If `None`, no mask kernel is built.
    src_key
        Mask for the source space of the kernel.
    tgt_key
        Mask for the target space of the kernel.
    copy
        If `True`, work on a copy of ``adata``; :meth:`filter` then returns that copy.
    kernel_params
        Kernel object parameters [length_scale].
    kwargs
        ``device``, ``dtype``, ``backend``, ``batch_size`` and ``max_src_size`` are consumed
        here; everything else is forwarded to the kernel computations
        [precomputed_kernel, n_neighbors, knn_batch_key].

    Examples
    --------
    >>> import scanpy as sc
    >>> import sift
    >>> adata = sc.read(...)
    >>> sft = sift.SiFT(
    ...     adata=adata,
    ...     kernel_key="cell_cycle_genes",
    ...     metric="RBF",
    ... )
    >>> sft.filter()
    """

    def __init__(
        self,
        adata: AnnData,
        kernel_key: Optional[str] = None,
        metric: st.Kernel_t = KernelType.RBF,
        mask_key: Optional[str] = None,
        src_key: Optional[str] = None,
        tgt_key: Optional[str] = None,
        copy: bool = False,
        kernel_params: Optional[Dict[str, Any]] = MappingProxyType({}),
        **kwargs: Any,
    ):
        metric = KernelType(metric)

        self._adata = adata.copy() if copy else adata
        self._kernel_key = kernel_key
        self._kernel_metric = metric
        self._src_key = src_key
        self._tgt_key = tgt_key
        self._stage = Stage.INITIALIZED
        self._kernel = None
        self._kernel_mask = None
        self._mask_key = None
        self._kernel_containers = {}
        self._copy = copy

        # Resolve the default backend first; explicit keyword arguments override
        # the derived device / dtype / backend below.
        backend = bu.get_backend(None)
        bu.assert_backend_available(backend)
        device = bu.get_device(backend)
        dtype = bu.get_dtype(backend)

        self._device = kwargs.pop("device", device)
        self._dtype = kwargs.pop("dtype", dtype)
        self._backend = kwargs.pop("backend", backend)
        self._batch_size = kwargs.pop("batch_size", BATCH_SIZE)
        self._max_src_size = kwargs.pop("max_src_size", MAX_SRC_SIZE)

        self._init_kernels(
            kernel_key=kernel_key,
            mask_key=mask_key,
            metric=metric,
            kernel_params=kernel_params,
            **kwargs,
        )

        self._model_summary_string = f"SiFT object with {metric} kernel"
        logger.info(f"initialized a SiFTer with {metric} kernel.")
        if self._mask_key is not None:
            logger.info(f"SiFTer has a mask_kernel of type {self._mask_key}")

    @staticmethod
    def _to_dense(mat: Any) -> Any:
        """Return ``mat`` as a dense array if it is a SciPy sparse container, else unchanged."""
        return mat.toarray() if issparse(mat) else mat

    @property
    def adata(self) -> AnnData:
        """The SiFT annotated data object."""
        return self._adata

    @adata.setter
    def adata(self, adata: AnnData) -> None:
        self._adata = adata

    @property
    def kernel(self) -> st.KernelLike:
        """The :mod:`pykeops` kernel object.

        Every access calls the underlying kernel on the stored containers.
        """
        return self._kernel(**self._kernel_containers)

    @kernel.setter
    def kernel(self, k: st.KernelLike) -> None:
        self._kernel = k

    def _init_kernels(
        self,
        kernel_key: Optional[str] = None,
        mask_key: Optional[str] = None,
        metric: st.Kernel_t = KernelType.RBF,
        precomputed_kernel: Optional[st.ArrayLike] = None,
        kernel_params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Sets the cell-cell similarity kernel.

        Parameters
        ----------
        kernel_key
            Key in :class:`anndata.AnnData` where information for distance is stored.
        mask_key
            Key in :attr:`anndata.AnnData.obs` for k-NN masking.
        metric
            The kernel type.
        precomputed_kernel
            The user can provide a pre-computed kernel, it is assumed cells are ordered as in :attr:`adata`.
        kernel_params
            Kernel object parameters [length_scale]. `None` is treated as no parameters.
        kwargs
            Keyword arguments for kernel computations [precomputed_kernel, n_neighbors, knn_batch_key].

        Returns
        -------
        Nothing, just initializes the kernels.
        """
        # ``None`` is an allowed value (and this method's default); unpacking it
        # directly would raise ``TypeError``.
        kernel_params = {} if kernel_params is None else kernel_params

        (
            self._kernel_containers["x"],
            self._kernel_containers["y"],
        ) = _prepare_kernel_container(
            adata=self.adata,
            kernel_key=kernel_key,
            metric=metric,
            precomputed_kernel=precomputed_kernel,
            src_key=self._src_key,
            tgt_key=self._tgt_key,
            **kwargs,
        )
        self._kernel = Kernel._create(
            metric,
            device=self._device,
            dtype=self._dtype,
            backend=self._backend,
            **kernel_params,
        )

        if mask_key is None:
            return

        # Build the k-NN containers and evaluate them once through a MAPPING
        # kernel; the evaluated result (``.k``) is kept dense as the mask.
        x, y = _prepare_kernel_container(
            adata=self.adata,
            kernel_key=mask_key,
            metric=KernelType.KNN,
            src_key=self._src_key,
            tgt_key=self._tgt_key,
            uns_key="mask_kernel",
            **kwargs,
        )
        _kernel_mask = Kernel._create(
            KernelType.MAPPING,
            device=self._device,
            dtype=self._dtype,
            backend=self._backend,
        )
        _kernel_mask(x=x, y=y)
        self._kernel_mask = self._to_dense(_kernel_mask.k)
        self._mask_key = mask_key

    def _materialize_kernel(
        self,
        k: Optional[st.ArrayLike] = None,
        to_cpu: Optional[bool] = True,
    ) -> st.KernelLike:
        """Materialize the initialized kernel.

        Parameters
        ----------
        k
            Kernel to materialize, if `None` the stored kernel is evaluated with ``materialize=True``.
        to_cpu: bool
            Whether to move kernel to CPU.

        Returns
        -------
        The materialized kernel object.

        Notes
        -----
        This operation can result in out-of-memory (OOM) error.
        """
        if k is None:
            k = self._kernel(materialize=True, **self._kernel_containers)
        if to_cpu and bu.is_torch_tensor(k):
            k = k.cpu().numpy()
        return self._to_dense(k)

    def materialize_kernel(
        self, kernel_containers: Optional[Dict] = None, to_cpu: bool = True
    ) -> st.KernelLike:
        """Materialize the initialized kernel.

        Parameters
        ----------
        to_cpu
            Whether to move kernel to CPU.
        kernel_containers
            Keyword arguments to the kernel call. Only used when no mask kernel is present;
            the masked kernel is always computed from the stored containers.

        Returns
        -------
        The materialized kernel after masking, if :attr:`._kernel_mask` is present.

        Notes
        -----
        This operation can result in out-of-memory (OOM) error.
        """
        if self._mask_key is not None:
            return self._materialize_kernel(self._mask_kernel(), to_cpu=to_cpu)

        if kernel_containers is None:
            kernel_containers = self._kernel_containers
        return self._materialize_kernel(
            self._kernel(materialize=True, **kernel_containers), to_cpu=to_cpu
        )

    def _mask_kernel(
        self, kernel_input_idx: Optional[Dict[str, Any]] = None
    ) -> st.ArrayLike:
        """Mask the kernel.

        Parameters
        ----------
        kernel_input_idx
            Dictionary with index for kernel input (keys ``"x"`` and ``"y"``).

        Returns
        -------
        The materialized kernel after masking, if :attr:`_kernel_mask` is present,
        otherwise the result of calling the kernel on the selected inputs.

        Notes
        -----
        This operation can result in out-of-memory (OOM) error.
        """
        containers = self._kernel_containers
        if kernel_input_idx is None:
            x_idx = y_idx = None
            x, y = containers["x"], containers["y"]
        else:
            x_idx, y_idx = kernel_input_idx["x"], kernel_input_idx["y"]
            x = containers["x"][x_idx, :]
            y = containers["y"][y_idx, :]

        if self._mask_key is None:
            return self._kernel(x=x, y=y)

        kernel_mask = self._kernel_mask
        if kernel_input_idx is not None:
            # Select the (rows x cols) sub-block. ``mask[x_idx, y_idx]`` would
            # pair the two index arrays element-wise instead, which fails (or
            # yields a diagonal) when a batch of rows meets all columns.
            kernel_mask = kernel_mask[x_idx, :][:, y_idx]

        # ``self.kernel`` calls the kernel on every access; evaluate it once.
        full_kernel = self.kernel
        if issparse(full_kernel):
            return full_kernel.multiply(kernel_mask)

        if self._backend == Backend.TORCH:
            # Keep the materialized kernel where it is (no CPU transfer).
            dense = self._materialize_kernel(
                k=self._kernel(x=x, y=y, materialize=True), to_cpu=False
            )
        else:
            dense = self._materialize_kernel(
                k=self._kernel(x=x, y=y, materialize=True)
            )
        return dense * kernel_mask

    def filter(
        self,
        embedding_key: Optional[str] = None,
        use_raw: Optional[bool] = False,
        pseudocount: bool = False,
        key_added: Optional[str] = None,
    ) -> Optional[AnnData]:
        """Perform filtering, subtract the projection of the kernel from the expression.

        Parameters
        ----------
        embedding_key
            Key in :attr:`anndata.AnnData.obsm` for which to compute the projection on.
            If `None`, use :attr:`anndata.AnnData.X`.
        use_raw
            Whether to use :attr:`anndata.AnnData.raw`, if present.
        pseudocount
            If `True`, shift the filtered expression by its minimum to avoid negative values.
        key_added
            Key used when saving the result. If `None`, use ``'{embedding_key}_sift'``.
            The result is saved to :attr:`anndata.AnnData.X`, :attr:`anndata.AnnData.layers` or
            :attr:`anndata.AnnData.obsm`, depending on the ``embedding_key``.

        Returns
        -------
        The annotated data object holding the filtered matrix if the object was created
        with ``copy=True``, otherwise `None`.

        Raises
        ------
        ValueError
            If ``use_raw`` is set but :attr:`anndata.AnnData.raw` is empty.
        """
        if use_raw:
            if self.adata.raw is None:
                raise ValueError("Received `use_raw=True`, but `adata.raw` is empty.")
            logger.info("Using `adata.raw`")
            # NOTE: from here on ``self.adata`` is a new object built from ``raw``.
            self.adata = self.adata.raw.to_adata()

        emb, type_emb, embedding_key = _get_embedding(
            adata=self.adata, embedding_key=embedding_key
        )

        if self._tgt_key is not None:
            tgt_idx = _get_mask_idx(adata=self.adata, key=self._tgt_key)
        else:
            tgt_idx = np.arange(self.adata.n_obs)
        if self._src_key is not None:
            src_idx = _get_mask_idx(adata=self.adata, key=self._src_key)
        else:
            src_idx = np.arange(self.adata.n_obs)

        logger.info(
            f"Filtering cell-cell similarity kernel using projection on `{embedding_key}`."
        )

        if self._backend == Backend.TORCH:
            import torch

            if issparse(self.kernel):
                k = self._mask_kernel()
                emb_sifted = emb[src_idx, :] - k @ emb[tgt_idx, :]
            elif src_idx.shape[0] > self._max_src_size:
                # Too many source cells for one kernel evaluation: process
                # source rows in batches against all target columns.
                n_src = src_idx.shape[0]
                n_batches = int(np.ceil(n_src / self._batch_size))
                batches = np.array_split(np.arange(n_src), n_batches)
                logger.info(f"Evaluating with {len(batches)} batches.")

                emb_src = emb[src_idx, :]
                y_idx = np.arange(tgt_idx.shape[0])
                emb_tgt = torch.tensor(
                    self._to_dense(emb[tgt_idx, :]),
                    dtype=self._dtype,
                    device=self._device,
                )

                emb_sifted_batches = []
                for batch in batches:
                    k = self._mask_kernel(kernel_input_idx={"x": batch, "y": y_idx})
                    emb_src_batch = torch.tensor(
                        self._to_dense(emb_src[batch, :]),
                        dtype=self._dtype,
                        device=self._device,
                    )
                    emb_sifted_batch = emb_src_batch - k @ emb_tgt
                    emb_sifted_batches.append(emb_sifted_batch.cpu().numpy())
                emb_sifted = np.concatenate(emb_sifted_batches)
            else:
                k = self._mask_kernel()
                emb = torch.tensor(
                    self._to_dense(emb),
                    dtype=self._dtype,
                    device=self._device,
                )
                emb_sifted = emb[src_idx, :] - k @ emb[tgt_idx, :]
        else:  # numpy backend
            k = self._mask_kernel()
            # Evaluate the kernel once for both the dtype and sparsity checks.
            kernel = self.kernel
            emb = self._to_dense(emb.astype(dtype=kernel.dtype))
            if issparse(kernel):
                emb_sifted = emb[src_idx, :] - k @ emb[tgt_idx, :]
            else:
                emb_sifted = emb[src_idx, :] - k.__matmul__(
                    emb[tgt_idx, :], backend="CPU"
                )

        # Normalise to a dense NumPy array before any in-place arithmetic; a
        # sparse result does not support in-place scalar subtraction.
        if bu.is_torch_tensor(emb_sifted):
            emb_sifted = emb_sifted.cpu().numpy()
        emb_sifted = self._to_dense(emb_sifted)

        if pseudocount:
            emb_sifted -= emb_sifted.min()

        # Rows outside the source mask stay zero.
        emb_sifted_full = np.zeros((self.adata.n_obs, emb_sifted.shape[-1]))
        emb_sifted_full[src_idx, :] = emb_sifted

        key_added, attr = _save_res(
            adata=self.adata,
            embedding=emb_sifted_full,
            type_emb=type_emb,
            type_res="sift",
            embedding_key=embedding_key,
            key_added=key_added,
        )

        self._stage = Stage.SIFTED
        if type_emb == "X":
            logger.info(
                "The data is `SiFTed`!\n"
                "The filtered embedding is stored in `adata.X`\n"
                "    Finish"
            )
        else:
            logger.info(
                f"The data is `SiFTed`!\n"
                f"The filtered embedding is stored in `adata.{attr}[{key_added!r}]`\n"
                f"    Finish"
            )

        if self._copy:
            return self.adata
        return None

    def plot_kernel(
        self, save_path: Optional[os.PathLike] = None, show: bool = True, **kwargs: Any
    ) -> Optional[plt.Axes]:
        """Visualize the cell-cell similarity kernel.

        Parameters
        ----------
        save_path
            Path where to save the figure.
        show
            If `False`, return :class:`matplotlib.axes.Axes`.
        kwargs
            Additional keyword arguments to :func:`sift._utils.plot_kernel()`.
            ``kernel_containers`` and ``groupby`` are extracted here; for a mapping
            kernel ``groupby`` is always replaced by the kernel key.

        Returns
        -------
        The axes object, if ``show = False``, otherwise `None`.
        """
        kernel_containers = kwargs.pop("kernel_containers", self._kernel_containers)
        groupby = kwargs.pop("groupby", None)
        if self._kernel_metric == KernelType.MAPPING:
            groupby = self._kernel_key

        materialized = self.materialize_kernel(
            kernel_containers=kernel_containers, to_cpu=True
        )
        return _plot_kernel(
            kernel=materialized,
            adata=self.adata,
            src_key=self._src_key,
            tgt_key=self._tgt_key,
            groupby=groupby,
            save_path=save_path,
            show=show,
            **kwargs,
        )
def sifter(
    adata: AnnData,
    kernel_key: Optional[str] = None,
    metric: st.Kernel_t = KernelType.RBF,
    src_key: Optional[str] = None,
    tgt_key: Optional[str] = None,
    knn_key: Optional[str] = None,
    batch_key: Optional[str] = None,
    embedding_key: Optional[str] = None,
    use_raw: Optional[bool] = False,
    pseudocount: bool = False,
    key_added: Optional[str] = None,
    copy: bool = False,
    **kwargs: Any,
) -> Optional[AnnData]:
    """Perform filtering, subtract the projection of the kernel from the expression.

    Convenience wrapper that builds a :class:`SiFT` object and calls :meth:`SiFT.filter`.

    Parameters
    ----------
    adata
        Annotated data object.
    kernel_key
        Key in :class:`anndata.AnnData` where information for distance is stored.
    metric
        Metric to use.
    src_key
        Mask for the source space of the kernel.
    tgt_key
        Mask for the target space of the kernel.
    knn_key
        Key in :attr:`anndata.AnnData.obs` for k-NN masking. Passed to :class:`SiFT` as ``mask_key``.
    batch_key
        Key in :attr:`anndata.AnnData.obs` for batch masking. Forwarded with the kernel keyword
        arguments only when not `None`.
    embedding_key
        Key in :attr:`anndata.AnnData.obsm` for which to compute the projection on.
        If `None`, use :attr:`anndata.AnnData.X`.
    use_raw
        Whether to use :attr:`anndata.AnnData.raw`, if present.
    pseudocount
        If `True`, add a pseudocount to filtered expression to avoid negative values.
    key_added
        Key used when saving the result. If `None`, use ``'{embedding_key}_sift'``.
        The result is saved to :attr:`anndata.AnnData.X`, :attr:`anndata.AnnData.layers` or
        :attr:`anndata.AnnData.obsm`, depending on the ``embedding_key``.
    copy
        If `True`, work on a copy of ``adata`` and return it; otherwise modify ``adata`` inplace.
    kwargs
        Additional keyword arguments [n_neighbors, precomputed_kernel].

    Returns
    -------
    The annotated data object holding the filtered matrix if ``copy=True``, otherwise `None`.
    """
    # ``SiFT`` enables k-NN masking through ``mask_key``; an unknown ``knn_key``
    # keyword would never reach it. A ``mask_key`` supplied via ``kwargs`` is
    # still honoured when ``knn_key`` is not given.
    mask_key = kwargs.pop("mask_key", None)
    if knn_key is not None:
        mask_key = knn_key

    if batch_key is not None:
        # NOTE(review): forwarded under its original name; the kernel keyword
        # documented on ``SiFT`` is ``knn_batch_key`` -- confirm which name the
        # kernel preparation expects.
        kwargs["batch_key"] = batch_key

    return SiFT(
        adata=adata,
        kernel_key=kernel_key,
        metric=metric,
        mask_key=mask_key,
        src_key=src_key,
        tgt_key=tgt_key,
        copy=copy,
        **kwargs,
    ).filter(
        embedding_key=embedding_key,
        use_raw=use_raw,
        pseudocount=pseudocount,
        key_added=key_added,
    )