# Source code for nntoolbox.vision.utils.utils
from torchvision.transforms import functional as F
from torch import Tensor
import cv2
from cv2 import imread, cvtColor
from numpy import ndarray
import numpy as np
# Names exported by this module when it is star-imported.
__all__ = [
    'gram_matrix',
    'is_image',
    'pil_to_tensor',
    'tensor_to_pil',
    'tensor_to_np',
    'cv2_read_image',
    'compute_output_shape',
]
def gram_matrix(x):
    """
    Compute the normalized Gram matrix for each element of a batch of feature maps.

    :param x: 4D tensor, unpacked as (n, c, h, w)
    :return: tensor of shape (n, c, c) holding, for each batch element, the inner products
        between its flattened (h * w) feature maps, divided by c * h * w
    """
    n, c, h, w = x.size()
    # reshape (rather than view) so that non-contiguous inputs, e.g. permuted or
    # channels_last tensors, are flattened instead of raising a RuntimeError
    features = x.reshape(n, c, -1)
    return (features @ features.transpose(1, 2)) / (c * h * w)
def is_image(filename):
    """
    Check if filename has a valid image extension (the check is case-insensitive)

    :param filename: name (or path) of the file, as a string
    :return: boolean indicating whether filename is a valid image filename
    """
    # str.endswith accepts a tuple of suffixes, so a single call covers every extension
    return filename.lower().endswith((".jpg", ".png", ".jpeg", ".gif", ".bmp"))
def pil_to_tensor(pil, device=None):
    """
    Convert an image to a tensor with an extra leading dimension of size 1

    :param pil: image to convert (passed to torchvision's ``to_tensor``)
    :param device: optional device to move the tensor to; if None, the tensor is not moved
    :return: the converted tensor, moved to ``device`` when one is given
    """
    tensor = F.to_tensor(pil).unsqueeze(0)
    if device is not None:
        # Tensor.to is not in-place: it returns the moved tensor, so the result must be
        # rebound, otherwise the device argument is silently ignored
        tensor = tensor.to(device)
    return tensor
def tensor_to_pil(tensor):
    """
    Convert an image tensor to a PIL image

    :param tensor: image tensor; if it has 4 dimensions, only its first element is converted
    :return: the image produced by torchvision's ``to_pil_image``
    """
    image = tensor[0] if len(tensor.shape) == 4 else tensor
    return F.to_pil_image(image)
def tensor_to_np(tensor: Tensor) -> ndarray:
    """Convert the tensor image to numpy format, moving one axis to the last position"""
    # 4 dimensions: axis 1 is moved last (axis 0 stays first); otherwise axis 0 is moved last
    axis_order = (0, 2, 3, 1) if len(tensor.shape) == 4 else (1, 2, 0)
    rearranged = tensor.permute(*axis_order)
    return rearranged.cpu().detach().numpy()
def cv2_read_image(image_path, to_float: bool=False, flag: int=cv2.IMREAD_COLOR) -> ndarray:
    """
    Read an image using cv2 and convert to RGB

    :param image_path: path to the image file; must have an extension accepted by is_image
    :param to_float: whether to convert image to float data type (values are divided by 255)
    :param flag: indicate mode for cv2 read image
    :return: the image as a numpy array; multi-channel images are converted from BGR to RGB,
        single-channel (2D) images are returned without color conversion
    :raises AssertionError: if image_path does not have a valid image extension
    :raises OSError: if cv2 fails to read the file
    """
    # Kept as an assert so callers still see AssertionError (note: skipped under python -O)
    assert is_image(image_path), "Not a valid image filename: {}".format(image_path)
    img = imread(image_path, flag)
    if img is None:
        # imread reports a missing or unreadable file by returning None instead of raising;
        # fail here with a clear message rather than with an obscure cvtColor error
        raise OSError("Failed to read image: {}".format(image_path))
    if img.ndim == 3:
        # Only multi-channel results need reordering; a 2D single-channel image (e.g. read
        # with cv2.IMREAD_GRAYSCALE) would make the BGR2RGB conversion raise
        img = cvtColor(img, cv2.COLOR_BGR2RGB)
    if to_float:
        img = img / 255
    return img
def compute_output_shape(inp_dim: int, padding: int, kernel_size: int, dilation: int, stride: int) -> int:
    """
    Compute floor((inp_dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

    :param inp_dim: input size along the dimension
    :param padding: padding size (counted twice in the formula)
    :param kernel_size: size of the kernel
    :param dilation: dilation factor applied to the kernel
    :param stride: step size
    :return: the output size as an integer. The result is not clamped, so a value <= 0 is
        returned as is instead of being wrapped around by an unsigned cast
    """
    # Extent covered by the dilated kernel along this dimension
    dilated_extent = dilation * (kernel_size - 1) + 1
    # Exact integer floor division: floor(a / s + 1) == a // s + 1, with no float round trip
    # and a plain int result (matching the annotation) for int inputs
    return (inp_dim + 2 * padding - dilated_extent) // stride + 1
# def is_color(image, batch: bool=True) -> bool:
# """
# Check whether the image(s) are colored properly (i.e., have 4 channels)
# :param image:
# :param batch:
# """
# if batch:
# return len(image.shape) == 4 and
# return len(image.shape) == 3 if