46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import numpy as np
|
|
|
|
def smartsum(x, dim=None):
    """Sum the finite entries of ``x`` along one axis, ignoring NaN and Inf.

    Python counterpart of the MATLAB ``smartsum`` helper.

    Parameters:
        x (array_like): Input data. Lists, tuples, scalars and ndarrays are
            accepted; scalars are treated as one-element vectors.
        dim (int, optional): Axis along which to sum. When omitted, the first
            axis whose length is greater than one is used (axis 0 if there is
            no such axis).

    Returns:
        ndarray or NumPy scalar: The sum of the finite values along ``dim``.
        Positions whose slice contains no finite value at all (including
        zero-length slices) are NaN. The result keeps the dtype of the plain
        sum unless a NaN has to be stored, in which case it is promoted to a
        floating-point (or complex) dtype.
    """
    # asarray handles lists, tuples and scalars alike; atleast_1d lets a
    # 0-d input be reduced along axis 0 like any other vector.
    x = np.atleast_1d(np.asarray(x))

    if dim is None:
        # MATLAB convention: operate along the first non-singleton dimension.
        dim = next((axis for axis, size in enumerate(x.shape) if size > 1), 0)

    finite = np.isfinite(x)

    # Non-finite entries contribute nothing to the sum.
    total = np.sum(np.where(finite, x, 0), axis=dim)

    # Slices with no finite entry at all must report NaN rather than 0.
    no_finite = np.logical_not(finite.any(axis=dim))
    if np.any(no_finite):
        # np.where works for both scalar and array totals (item assignment
        # would fail on a NumPy scalar) and promotes integer sums to float so
        # NaN can be represented. Indexing with () turns a 0-d array back
        # into a scalar and leaves n-d arrays unchanged.
        total = np.where(no_finite, np.nan, total)[()]

    return total