Understanding NumPy's einsum
I'm struggling to understand exactly how einsum works. I've looked at the documentation and a few examples, but it's not seeming to stick.
Here's an example we went over in class:
C = np.einsum("ij,jk->ki", A, B)
for two arrays, A and B.
I think this would take A^T * B, but I'm not sure (it's taking the transpose of one of them, right?). Can anyone walk me through exactly what's happening here (and in general when using einsum)?
(Note: this answer is based on a short blog post about einsum I wrote a while ago.)
What does einsum do?
Imagine that we have two multi-dimensional arrays, A and B. Now let's suppose we want to...
- multiply A with B in a particular way to create a new array of products; and then maybe
- sum this new array along particular axes; and then maybe
- transpose the axes of the new array in a particular order.
There's a good chance that einsum will help us do this faster and more memory-efficiently than combinations of NumPy functions like multiply, sum and transpose will allow.
How does einsum work?
Here's a simple (but not completely trivial) example. Take the following two arrays:
A = np.array([0, 1, 2])
B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])
We will multiply A and B element-wise and then sum along the rows of the new array. In "normal" NumPy we'd write:
>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])
So here, the indexing operation on A lines up the first axes of the two arrays so that the multiplication can be broadcast. The rows of the array of products are then summed to return the answer.
Now if we wanted to use einsum instead, we could write:
>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])
The signature string 'i,ij->i' is the key here and needs a little bit of explaining. You can think of it in two halves. On the left-hand side (left of the ->) we've labelled the two input arrays. To the right of ->, we've labelled the array we want to end up with.
Here is what happens next:
- A has one axis; we've labelled it i. And B has two axes; we've labelled axis 0 as i and axis 1 as j.
- By repeating the label i in both input arrays, we are telling einsum that these two axes should be multiplied together. In other words, we're multiplying array A with each column of array B, just like A[:, np.newaxis] * B does.
- Notice that j does not appear as a label in our desired output; we've just used i (we want to end up with a 1D array). By omitting the label, we're telling einsum to sum along this axis. In other words, we're summing the rows of the products, just like .sum(axis=1) does.
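To make the three steps above concrete, here is a rough sketch that spells out 'i,ij->i' as explicit Python loops and compares it with both the einsum call and the broadcasting version. The variable names manual, via_einsum and via_broadcast are only for illustration, and the loops are meant to show the arithmetic, not how einsum is implemented internally.

```python
import numpy as np

A = np.array([0, 1, 2])
B = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])

# Label i is shared by A and B, so A[i] is paired with row i of B.
# Label j is missing from the output, so we sum over it.
manual = np.zeros(A.shape[0], dtype=B.dtype)
for i in range(B.shape[0]):
    for j in range(B.shape[1]):
        manual[i] += A[i] * B[i, j]

via_einsum = np.einsum('i,ij->i', A, B)
via_broadcast = (A[:, np.newaxis] * B).sum(axis=1)

print(manual)         # [ 0 22 76]
print(via_einsum)     # [ 0 22 76]
print(via_broadcast)  # [ 0 22 76]
```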
That's basically all you need to know to use einsum. It helps to play about a little; if we leave both labels in the output, 'i,ij->ij', we get back a 2D array of products (same as A[:, np.newaxis] * B). If we give no output labels, 'i,ij->', we get back a single number (same as doing (A[:, np.newaxis] * B).sum()).
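The two variations just mentioned can be checked directly. This is a small sketch using the same A and B as before; the names products and total are just illustrative.

```python
import numpy as np

A = np.array([0, 1, 2])
B = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])

# Keep both labels: nothing is summed, we get the 2D array of products.
products = np.einsum('i,ij->ij', A, B)

# No output labels: every axis is summed, leaving a single number.
total = np.einsum('i,ij->', A, B)

print(products)
# [[ 0  0  0  0]
#  [ 4  5  6  7]
#  [16 18 20 22]]
print(total)  # 98
```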
The great thing about einsum, however, is that it does not build a temporary array of products first; it just sums the products as it goes. This can lead to big savings in memory use.
A slightly bigger example
To explain the dot product, here are two new arrays:
A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])
B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])
We will compute the dot product using np.einsum('ij,jk->ik', A, B). Here's a picture showing the labelling of A and B and the output array that we get from the function:
You can see that label j is repeated - this means we're multiplying the rows of A with the columns of B. Furthermore, the label j is not included in the output - we're summing these products. Labels i and k are kept for the output, so we get back a 2D array.
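As a quick check of the labelling, the sketch below runs the dot-product string on these two arrays and compares it with the @ operator. It also tries the 'ij,jk->ki' string from the question: since the output labels are swapped, it should come out as the transpose of the matrix product, (A @ B).T, and not A^T * B. The names dot and swapped are illustrative only.

```python
import numpy as np

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])
B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

# j is shared and dropped from the output: multiply, then sum over j.
dot = np.einsum('ij,jk->ik', A, B)
print(dot)
# [[ 2  3  1]
#  [ 4  6  2]
#  [10 15  5]]
print(np.array_equal(dot, A @ B))  # True

# Same sum over j, but the output axes are written in the order k, i.
swapped = np.einsum('ij,jk->ki', A, B)
print(np.array_equal(swapped, (A @ B).T))  # True
```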
It might be even clearer to compare this result with the array where the label j is not summed. Below, on the left, you can see the 3D array that results from writing np.einsum('ij,jk->ijk', A, B) (i.e. we've kept label j):
Summing axis j gives the expected dot product, shown on the right.
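The same comparison can be done in code. This is a minimal sketch, assuming the A and B from this section: it builds the unsummed 3D array, then sums over axis 1 (the position of label j in 'ijk') and checks that the result matches the dot product. The name products_3d is just for illustration.

```python
import numpy as np

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])
B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

# Keep label j: products_3d[i, j, k] == A[i, j] * B[j, k], nothing summed.
products_3d = np.einsum('ij,jk->ijk', A, B)
print(products_3d.shape)  # (3, 3, 3)

# Summing over axis 1 (label j) collapses it to the dot product.
summed = products_3d.sum(axis=1)
print(np.array_equal(summed, np.einsum('ij,jk->ik', A, B)))  # True
```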
Some exercises
To get more of a feel for einsum, it can be useful to implement familiar NumPy array operations using the subscript notation. Anything that involves combinations of multiplying and summing axes can be written using einsum.
Let A and B be two 1D arrays with the same length. For example, A = np.arange(10) and B = np.arange(5, 15).
- The sum of A can be written: np.einsum('i->', A)
- Element-wise multiplication, A * B, can be written: np.einsum('i,i->i', A, B)
- The inner product or dot product, np.inner(A, B) or np.dot(A, B), can be written: np.einsum('i,i->', A, B)  # or just use 'i,i'
- The outer product, np.outer(A, B), can be written: np.einsum('i,j->ij', A, B)
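The 1D exercises above can be checked against the familiar NumPy functions. This is only a sanity-check sketch using the example arrays suggested in the text; the result names are illustrative.

```python
import numpy as np

A = np.arange(10)
B = np.arange(5, 15)

total = np.einsum('i->', A)             # sum of A
elementwise = np.einsum('i,i->i', A, B) # A * B
inner = np.einsum('i,i->', A, B)        # np.inner(A, B)
inner_implicit = np.einsum('i,i', A, B) # same thing, output labels left implicit
outer = np.einsum('i,j->ij', A, B)      # np.outer(A, B)

print(total == A.sum())                        # True
print(np.array_equal(elementwise, A * B))      # True
print(inner == np.inner(A, B))                 # True
print(inner_implicit == np.dot(A, B))          # True
print(np.array_equal(outer, np.outer(A, B)))   # True
```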
For 2D arrays, C and D, provided that the axes are compatible lengths (both the same length or one of them has length 1), here are a few examples:
- The trace of C (sum of main diagonal), np.trace(C), can be written: np.einsum('ii', C)
- Element-wise multiplication of C and the transpose of D, C * D.T, can be written: np.einsum('ij,ji->ij', C, D)
- Multiplying each element of C by the array D (to make a 4D array), C[:, :, None, None] * D, can be written: np.einsum('ij,kl->ijkl', C, D)
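The 2D examples can be verified the same way. The text does not give concrete values for C and D, so the two 3x3 arrays below are assumptions picked purely for illustration, as are the result names.

```python
import numpy as np

# Hypothetical example arrays (not from the original answer).
C = np.arange(9).reshape(3, 3)
D = np.arange(9, 18).reshape(3, 3)

trace = np.einsum('ii', C)                 # repeated label on one array: diagonal, summed
times_dt = np.einsum('ij,ji->ij', C, D)    # D's axes are labelled in swapped order
four_d = np.einsum('ij,kl->ijkl', C, D)    # no shared labels: every pairwise product

print(trace == np.trace(C))                                    # True
print(np.array_equal(times_dt, C * D.T))                       # True
print(four_d.shape)                                            # (3, 3, 3, 3)
print(np.array_equal(four_d, C[:, :, None, None] * D))         # True
```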