Source code for ap_features.filters

import logging
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import numpy as np

from . import utils
from .utils import Array

logger = logging.getLogger(__name__)


[docs] class InvalidFilter(RuntimeError): pass
[docs] class Filters(str, Enum): apd30 = "apd30" apd50 = "apd50" apd70 = "apd70" apd80 = "apd80" apd90 = "apd90" length = "length" time_to_peak = "time_to_peak"
[docs] def filter_signals( data: Union[Sequence[Array], Dict[Any, Array]], x: float = 1, center="mean", ) -> Sequence[int]: if len(data) == 0: return [] values = data if isinstance(data, dict): values = list(data.values()) # Check that all arrays have the same length v0 = values[0] N = len(v0) for v in values: if len(v) != N: raise RuntimeError("Unequal length of arrays") all_indices = [] for v in values: indices = within_x_std(v, x, center) # If no indices are with the tolerance then we just # include everything all_indices.append(indices if len(indices) > 0 else np.arange(N)) return utils.intersection(all_indices)
[docs] def within_x_std(arr: Array, x: float = 1.0, center="mean") -> Sequence[int]: """Get the indices in the array that are within x standard deviations from the center value Parameters ---------- arr : Array The array with values x : float, optional Number of standard deviations, by default 1.0 center : str, optional Center value, Either "mean" or "median", by default "mean" Returns ------- Sequence[int] Indices of the values that are within x standard deviations from the center value """ if len(arr) == 0: return [] msg = f"Expected 'center' to be 'mean' or 'median', got {center}" assert center in ["mean", "median"], msg mu = np.mean(arr) if center == "mean" else np.median(arr) std = np.std(arr) within = [abs(a - mu) <= x * std for a in arr] return np.where(within)[0]
[docs] def filt(y: Array, kernel_size: int = 3): """ Filer signal using a median filter. Default kernel_size is 3 """ logger.debug("Filter image") from scipy.signal import medfilt smooth_trace = medfilt(y, kernel_size) return smooth_trace
[docs] def remove_points( x: Array, y: Array, t_start: float, t_end: float, normalize: bool = True, ) -> Tuple[Array, Array]: """ Remove points in x and y between start and end. Also make sure that the new x starts a zero if normalize = True """ if not len(x) == len(y): raise ValueError( f"Expected x and y to have same length, got len(x) = {len(x)} and len(y) = {len(y)}", ) start = next(i for i, t in enumerate(x) if t > t_start) - 1 try: end = next(i for i, t in enumerate(x) if t > t_end) except StopIteration: end = len(x) - 1 logger.debug( ("Remove points for t={} (index:{}) to t={} (index:{})" "").format( t_start, start, t_end, end, ), ) x0 = x[:start] x1 = np.subtract(x[end:], x[end] - x[start]) x_new = np.concatenate((x0, x1)) if normalize: x_new -= x_new[0] y_new = np.concatenate((y[:start], y[end:])) return x_new, y_new
[docs] def find_spike_points(pacing, spike_duration: int = 7) -> List[int]: """ Remove spikes from signal due to pacing. We assume that there is a spike starting at the time of pacing and the spike disappears after some duration. These points will be deleted from the signal Parameters ---------- pacing : array The pacing amplitude of same length as y spike_duration: int Duration of the spikes Returns ------- np.ndarray A list of indices containing the spikes """ if spike_duration == 0: return [] # Find time of pacing (inds,) = np.where(np.diff(np.array(pacing, dtype=float)) > 0) if len(inds) == 0: logger.warning("No pacing found. Spike removal not possible.") return [] spike_points = np.concatenate( [np.arange(i, i + spike_duration) for i in inds], ).astype(int) return spike_points