Published on

torch.tril

Authors
  • avatar
    Name
    Inhwan Cho
    Twitter

torch.tril

  • torch.tril(input, diagonal=0, *, out=None)

  • 행렬의 아래쪽 삼각형 부분 (2 차원 텐서) 또는 행렬의 배치 input 을 반환합니다 .[행렬의 오른쪽 부분을(0으로 만듬)]

  • attention 구조의 mask를 만들 때 많이 사용되는 함수입니다.

  • 무슨 말인지 이해가 잘 안가실겁니다. 예제 출력 코드를 보면 바로 이해가 갈겁니다.

a = torch.ones((5, 5))
torch.tril(a)

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])


a = torch.ones((5, 5))
torch.tril(a, diagonal=1)

tensor([[1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])