
torch.squeeze()

bitpoint 2024. 7. 31. 07:17
import torch

x = torch.rand(1, 1, 2, 3)
x = x.squeeze() # [1, 1, 2, 3] -> [2, 3]
print(x)

 

tensor([[0.7293, 0.5819, 0.9878],
        [0.6379, 0.6204, 0.8290]])

 

y = torch.rand(1, 1, 4)
z = y.squeeze()[2] # [1, 1, 4] -> [4], then index 2 -> 0-dim (scalar) tensor
print(z)


tensor(0.9240)
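
The calls above use no argument, so every size-1 dimension is dropped. As a small sketch (the shape and variable names here are only illustrative, not from the examples above), passing a `dim` argument should remove just that one dimension, and leave the tensor unchanged when that dimension is not of size 1:

```python
import torch

x = torch.rand(1, 3, 1, 2)

# No argument: all size-1 dimensions are removed.
all_squeezed = x.squeeze()    # [1, 3, 1, 2] -> [3, 2]

# dim=0: only dimension 0 is removed.
first_only = x.squeeze(0)     # [1, 3, 1, 2] -> [3, 1, 2]

# dim=1 has size 3, so nothing is removed.
unchanged = x.squeeze(1)      # [1, 3, 1, 2] -> [1, 3, 1, 2]

print(all_squeezed.shape, first_only.shape, unchanged.shape)
```

Giving `dim` explicitly can be safer when a batch dimension might happen to be 1, since a bare `squeeze()` would drop it as well.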
