2025-06-05 08:48:33 +02:00

46 lines
1.4 KiB
Python

import numpy as np
def smartsum(x, dim=None):
    """
    Sum an array along one dimension while ignoring NaN and Inf values.

    Python implementation of the MATLAB smartsum function.

    Parameters:
        x (array_like): Input values. ndarrays, lists, tuples and plain
            scalars are accepted. Input that is not floating-point or
            complex is converted to float so NaN can appear in the result.
        dim (int, optional): Axis along which to compute the sum. When
            omitted, the first axis whose length is greater than 1 is used,
            falling back to axis 0 if there is no such axis.

    Returns:
        float or ndarray: Sum of the finite values along ``dim``. Entries
        for which no finite value exists along ``dim`` are NaN. When the
        reduction leaves no dimensions (e.g. 1-D input) a scalar is returned.
    """
    # atleast_1d lets a bare scalar be reduced along axis 0 like a 1-element vector.
    x = np.atleast_1d(np.asarray(x))

    # Integer/bool arrays cannot store NaN, which is needed below to flag
    # positions that have no finite value at all.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)

    # Default axis: first non-singleton dimension, or 0 if every axis has length <= 1.
    if dim is None:
        dim = next((axis for axis, size in enumerate(x.shape) if size > 1), 0)

    # Zero out NaN/Inf so they do not contribute to the sum.
    finite = np.isfinite(x)
    total = np.sum(np.where(finite, x, 0), axis=dim)

    # np.where (instead of item assignment) also works when the reduction
    # produced a scalar, which does not support ``result[...] = nan``.
    no_finite = np.logical_not(np.any(finite, axis=dim))
    result = np.where(no_finite, np.nan, total)

    # Indexing with () turns a 0-d array into a scalar and leaves n-d arrays unchanged.
    return result[()]