Source code for nlpatl.models.embeddings.torchvision

from typing import List, Optional, Callable
import numpy as np

try:
	import torch
	import torchvision

	MODEL_FOR_TORCH_VISION_MAPPING_NAMES = {
		'resnet18': torchvision.models.resnet18,
		'alexnet': torchvision.models.alexnet,
		'vgg16': torchvision.models.vgg16
	}
except ImportError:
	# No installation required if not using this function
	pass

from nlpatl.models.embeddings.embeddings import Embeddings


[docs]class TorchVision(Embeddings): """ A wrapper of torch vision class. :param model_name_or_path: torch vision model name. Possible values are `resnet18`, `alexnet` and `vgg16`. :type model_name_or_path: str :param batch_size: Batch size of data processing. Default is 16 :type batch_size: int :param model_config: Model paramateters. Refer to https://pytorch.org/vision/stable/models.html :type model_config: dict :param transform: Preprocessing function :type transform: :class: `torchvision.transforms` :param name: Name of this embeddings :type name: str >>> import nlpatl.models.embeddings as nme >>> model = nme.TrochVision() """ def __init__(self, model_name_or_path: str, batch_size: int = 16, model_config: dict = {'pretrained': True}, transform: Optional[Callable] = None, name: str = 'torchvision'): super().__init__(batch_size=batch_size, name=name) self.model_name_or_path = model_name_or_path self.model_config = model_config self.transform = transform if model_name_or_path in MODEL_FOR_TORCH_VISION_MAPPING_NAMES: self.model = MODEL_FOR_TORCH_VISION_MAPPING_NAMES[model_name_or_path]( **model_config) self.model.eval() else: raise ValueError('`{}` does not support. Supporting {} only'.format( model_name_or_path, '`' + '`,`'.join( MODEL_FOR_TORCH_VISION_MAPPING_NAMES.keys()) + '`')) @staticmethod def get_mapping() -> dict: return MODEL_FOR_TORCH_VISION_MAPPING_NAMES
[docs] def convert(self, x: List[np.ndarray]) -> np.ndarray: results = [] for batch_inputs in self.batch(x, self.batch_size): with torch.no_grad(): features = [ self.transform(img) if self.transform else img for img in batch_inputs] results.append(self.model(torch.stack(features))) return torch.cat(results).detach().numpy()