- Published on
torch.tril
- Authors
- Name
- Inhwan Cho
torch.tril
torch.tril(input, diagonal=0, *, out=None)
행렬의 아래쪽 삼각형 부분 (2 차원 텐서) 또는 행렬의 배치 input 을 반환합니다 .[행렬의 오른쪽 부분을(0으로 만듬)]
attention 구조의 mask를 만들 때 많이 사용되는 함수입니다.
무슨 말인지 이해가 잘 안가실겁니다. 예제 출력 코드를 보면 바로 이해가 갈겁니다.
a = torch.ones((5, 5))
torch.tril(a)
tensor([[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]])
a = torch.ones((5, 5))
torch.tril(a, diagonal=1)
tensor([[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])