您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

在numpy或pytorch中自动获取对角矩阵条纹

在numpy或pytorch中自动获取对角矩阵条纹

stride_tricks 做到这一点:

>>> import numpy as np
>>> 
>>> def stripe(a):
...    a = np.asanyarray(a)
...    *sh, i, j = a.shape
...    assert i >= j
...    *st, k, m = a.strides
...    return np.lib.stride_tricks.as_strided(a, (*sh, i-j+1, j), (*st, k, k+m))
... 
>>> a = np.arange(24).reshape(6, 4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])
>>> stripe(a)
array([[ 0,  5, 10, 15],
       [ 4,  9, 14, 19],
       [ 8, 13, 18, 23]])

如果a一个数组,则会创建一个可写的视图,这意味着,如果您觉得这样,可以执行以下操作:

>>> stripe(a)[...] *= 10
>>> a
array([[  0,   1,   2,   3],
       [ 40,  50,   6,   7],
       [ 80,  90, 100,  11],
       [ 12, 130, 140, 150],
       [ 16,  17, 180, 190],
       [ 20,  21,  22, 230]])

更新:可以相同的方式获得从左下到右上的条纹。仅很小的复杂性:它不基于与原始数组相同的地址。

>>> def reverse_stripe(a):
...     a = np.asanyarray(a)
...     *sh, i, j = a.shape
...     assert i >= j
...     *st, k, m = a.strides
...     return np.lib.stride_tricks.as_strided(a[..., j-1:, :], (*sh, i-j+1, j), (*st, k, m-k))
... 
>>> a = np.arange(24).reshape(6, 4)
>>> reverse_stripe(a)
array([[12,  9,  6,  3],
       [16, 13, 10,  7],
       [20, 17, 14, 11]])
其他 2022/1/1 18:34:02 有435人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶