torch函数

torch.clamp()

torch.clamp(input, min, max, out=None) → Tensor
参数:

  • input (Tensor) – 输入张量
  • min (Number) – 限制范围下限
  • max (Number) – 限制范围上限
  • out (Tensor, optional) – 输出张量

作用:将输入夹紧至某一区间

1
output = torch.clamp(input, 0, 1) #输入值<0则为0,>1则为1  

torch.round()

torch.round(input, out=None)
作用:返回一个新张量,将输入input张量每个元素舍入到最近的整数

1
2
3
a = torch.tensor(0.6)
b = torch.round(a) # a = 0.6, b = 1.0
a = a.round() # a = 1.0, b = 1.0

torch.sign()

torch.sign(input, out=None) → Tensor
作用:返回tensor的符号

1
2
3
4
a = torch.tensor([-1.1, 0., 1.1])
sign = torch.sign(a)
print(a.sign()) # tensor([-1., 0., 1.])
print(sign) # tensor([-1., 0., 1.])

detach() & clone()

  • torch.detach()
  1. 新的tensor会脱离计算图,不会牵扯梯度计算;
  2. 浅拷贝,和原先的tensor指向同一内存;
  • torch.clone()
  1. 新的tensor充当中间变量,会保留在计算图中,参与梯度计算(回传叠加),但是一般不会保留自身梯度;
  2. 深拷贝,开辟新的内存空间;