본문 바로가기
AI-Tech 부스트캠프/파이토치

torch.stack() 과 torch.cat()

by Alan_Kim 2023. 4. 8.
728x90
반응형

 

torch.stack() 과 torch.cat() 모두 행렬을 연결하는데 쓰인다.

하지만 둘의 다른점을 비교하지 않으면 가끔 혼란스러울 수 있다.

따라서 한번 정리하고 넘어가려 한다.

torch.cat()은 주어진 차원을 기준으로 주어진 텐서들을 연결(concatenate)한다.
torch.stack()은 새로운 차원으로 주어진 텐서들을 연결한다.

 

import torch

x = torch.FloatTensor([1,4])
y = torch.FloatTensor([2,5])
z = torch.FloatTensor([3,6])

print(x.shape) #torch.Size([2])

print(torch.stack([x,y,z])) 
# tensor([[1., 4.],
#        [2., 5.],
#        [3., 6.]])
print(torch.stack([x,y,z]).shape) # torch.Size([3, 2])
print(torch.stack([x,y,z], dim=1)) 
# tensor([[1., 2., 3.],
#        [4., 5., 6.]])
print(torch.stack([x,y,z], dim=1).shape) #torch.Size([2, 3])

print(torch.cat([x,y,z],dim=0)) # tensor([1., 4., 2., 5., 3., 6.])
print(torch.cat([x,y,z],dim=0).shape) # torch.Size([6])
print(torch.cat([x.unsqueeze(0), y.unsqueeze(0), z.unsqueeze(0)], dim=0)) 
# tensor([[1., 4.],
#        [2., 5.],
#        [3., 6.]])
print(torch.cat([x.unsqueeze(0), y.unsqueeze(0), z.unsqueeze(0)], dim=0).shape) # torch.Size([3, 2])
728x90
반응형

댓글