I've recently been studying the Detectron2 project.

I ran into this error:


```
ImportError: cannot import name '_NewEmptyTensorOp' from 'torchvision.ops.misc' (C:\Users\****\Anaconda3\envs\torch38\lib\site-packages\torchvision\ops\misc.py)
```

The cause is the torchvision version. utils/misc.py contains the following code:

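The snippet is meant as a compatibility shim: the legacy imports should only run on torchvision older than 0.7. The check, however, converts the first three characters of the version string to a float. That works for versions like 0.5 or 0.6, but from 0.10.0 onward `"0.10.0"[:3]` is `"0.1"`, so `float(...)` returns 0.1, which is less than 0.5. The code therefore takes the oldest branch and tries to import `_NewEmptyTensorOp`, which the installed torchvision no longer provides. The fix is to comment this section out.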

```python
# Original check: float() of the first three characters of the version string.
# For torchvision "0.10.0" this is float("0.1") == 0.1 < 0.5, so this legacy branch runs
if float(torchvision.__version__[:3]) < 0.5:
    import math
    from torchvision.ops.misc import _NewEmptyTensorOp
    def _check_size_scale_factor(dim, size, scale_factor):
        # type: (int, Optional[List[int]], Optional[float]) -> None
        if size is None and scale_factor is None:
            raise ValueError("either size or scale_factor should be defined")
        if size is not None and scale_factor is not None:
            raise ValueError("only one of size or scale_factor should be defined")
        if not (scale_factor is not None and len(scale_factor) != dim):
            raise ValueError(
                "scale_factor shape must match input shape. "
                "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
            )
    def _output_size(dim, input, size, scale_factor):
        # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
        assert dim == 2
        _check_size_scale_factor(dim, size, scale_factor)
        if size is not None:
            return size
        # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
        assert scale_factor is not None and isinstance(scale_factor, (int, float))
        scale_factors = [scale_factor, scale_factor]
        # math.floor might return float in py2.7
        return [
            int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
        ]
# Same slicing problem: any 0.1x release is also read as 0.1 here
elif float(torchvision.__version__[:3]) < 0.7:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size
```
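
The slicing problem can be reproduced without torchvision installed. The sketch below compares the original `float(version[:3])` logic with a comparison on integer tuples. `naive_version` and `release_tuple` are hypothetical helpers written only for this illustration, and `release_tuple` assumes a plain numeric release (optionally with a local suffix such as `+cu113`); in real code, `packaging.version.parse`, as used in the official fix, is the more robust choice.

```python
def naive_version(version_str):
    # Reproduces the original check: float() of the first three characters
    return float(version_str[:3])


def release_tuple(version_str):
    # Hypothetical helper: turn "0.10.0+cu111" into (0, 10, 0).
    # Assumes a plain numeric release; pre-release tags like "0.11.0a0" are not handled.
    release = version_str.split("+")[0]
    return tuple(int(part) for part in release.split(".")[:3])


for v in ["0.5.0", "0.6.1", "0.8.2", "0.10.0", "0.11.1+cu113"]:
    naive_is_old = naive_version(v) < 0.5       # condition of the first (legacy) branch
    really_old = release_tuple(v) < (0, 5)      # what that condition was meant to test
    print(v, naive_is_old, really_old)
```

Only "0.10.0" and "0.11.1+cu113" disagree: the naive check reads both as 0.1 and wrongly sends them into the `< 0.5` branch, while tuple comparison orders 10 after 7 correctly.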

The official repository has since fixed this by parsing the version properly:

```python
from packaging import version  # provides version.parse for correct version ordering

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size
```
