- Published on
Einsum (Einstein Summation)
- Authors
- Name
- Inhwan Cho
einsum
참고 : Aladdin youtube
Einsum
은 Einstein Summation Convention으로 연산을 하는 방법입니다.연산을 통해 내적(Dot products), 외적(Outer porducts), 전치(Transpose), 행렬곱(Matmul) 등을 표현할 수 있으며,
형태(dim, shape)을 관리할 때 매우 유용하다.
einsum은 numpy, torch, tensorflow 에서 사용 가능 하다.
ex- numpy.einsum(), torch.einsum(), tensorflow.einsum()
- 간단하게 아래처럼 사용할 수 있습니다.
- 차원 표현으로
ijk...
를 많이 사용됩니다. - a,b 중 같은 차원이라면 동일한 알파벳으로 입력해주기.
einsum의 통상적인 사용방법은 다음과 같습니다. torch인 a.shape==(2,3,4),b.shape(3,4,1)가 있다면,
torch.einsum('ijk , jka -> jki' , [a,b])
결과는 [3,4,2] 라는 식으로 나옵니다.
- 수학적으로 표현하자면 너무 복잡해지니 예시를 통해 간단한 사용 방법을 익혀봅시다.
예시
# MATRIX TRANSPOSE
import torch
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->ji', [a])
tensor([[0, 1, 2],
[3, 4, 5]])
tensor([[0, 3],
[1, 4],
[2, 5]])
# SUM
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15) # 6!
# COLUMN SUM
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->j', [a])
tensor([[0, 1, 2],
[3, 4, 5]])
# 0+3 , 1+4, 2+5
tensor([3, 5, 7])
# ROW SUM
a = torch.arange(6).reshape(2, 3)
print(a)
torch.einsum('ij->i', [a])
tensor([[0, 1, 2], #0+1+2->3
[3, 4, 5]]) #3+4+5->12
tensor([ 3, 12])
# MATRIX-VECTOR MULTIPLICATION
a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([ 5, 14])
# 행렬곱과 값이 동일
np.matmul(a,b)
tensor([ 5, 14])
# MATRIX-MATRIX MULTIPLICATION
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
torch.einsum('ik,kj->ij', [a, b])
tensor([[ 25, 28, 31, 34, 37],
[ 70, 82, 94, 106, 118]])
# DOT PRODUCT(vector)
a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
tensor(14)
# DOT PRODUCT(matrix)
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145)
# HADAMARD PRODUCT
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
# OUTER PRODUCT
a = torch.arange(3)
b = torch.arange(3,7) #[3, 4, 5, 6]
c = torch.einsum('i,j->ij', [a, b])
print(a.shape,b.shape,c.shape)
c
torch.Size([3]) torch.Size([4]) torch.Size([3, 4])
tensor([[ 0, 0, 0, 0],
[ 3, 4, 5, 6],
[ 6, 8, 10, 12]])