62 lines
2.0 KiB
Python
62 lines
2.0 KiB
Python
import numpy as np
|
|
from converted_code.smartmean import smartmean
|
|
|
|
def smartstd(x, dim=None):
    """
    Compute the standard deviation of ``x`` while ignoring NaN and Inf values.

    Python implementation of the MATLAB ``smartstd`` function.
    Uses N (not N-1) for normalization, where N is the number of finite
    samples along ``dim``.

    Parameters:
        x (array-like): Input data; anything ``np.asarray`` accepts.
        dim (int, optional): Zero-based axis along which to compute the
            standard deviation. Negative values count from the last axis.
            Defaults to the first axis whose size is greater than 1
            (axis 0 when there is none).

    Returns:
        ndarray: Standard deviation along the specified dimension, ignoring
        NaN/Inf. The result has whatever shape ``smartmean`` produces for
        that reduction.

    Raises:
        ValueError: If ``dim`` is not a valid axis of ``x``.
    """
    # Accept lists, tuples and scalars alike; ndarrays with ndim >= 1 pass
    # through unchanged.
    x = np.atleast_1d(np.asarray(x))

    if dim is None:
        # Mirror MATLAB: default to the first non-singleton dimension.
        dim = next(
            (axis for axis, size in enumerate(x.shape) if size > 1), 0
        )
    elif not -x.ndim <= dim < x.ndim:
        raise ValueError(
            "dim={} is out of range for an array with {} dimension(s)".format(
                dim, x.ndim
            )
        )
    elif dim < 0:
        # Normalize so the helper always receives a non-negative axis.
        dim += x.ndim

    # Mean of the finite samples along ``dim`` (delegated to smartmean).
    mean_x = smartmean(x, dim)

    # Give the mean a length-1 axis at ``dim`` so plain broadcasting centers
    # the data. The element count is the same whether smartmean dropped the
    # reduced axis or kept it, so the reshape is valid for either layout and
    # for any number of dimensions.
    keep_shape = list(x.shape)
    keep_shape[dim] = 1
    x_centered = x - np.reshape(mean_x, keep_shape)

    # Mark non-finite samples as NaN rather than 0: a zero would be averaged
    # in as a valid deviation and bias the result towards zero.
    # NOTE(review): relies on smartmean skipping NaN entries, as the mean
    # computation above already does -- confirm against smartmean.
    squared_diff = np.where(np.isfinite(x), np.square(x_centered), np.nan)

    # Root of the mean squared deviation (normalized by N, not N-1).
    result = np.sqrt(smartmean(squared_diff, dim))

    return result