nlpatl.models.embeddings.torchvision

class nlpatl.models.embeddings.torchvision.TorchVision(model_name_or_path, batch_size=16, model_config={'pretrained': True}, transform=None, name='torchvision')

Bases: nlpatl.models.embeddings.embeddings.Embeddings

A wrapper around torchvision models that produces embeddings.

Parameters
  • model_name_or_path (str) – Name of the torchvision model. Possible values are resnet18, alexnet, and vgg16.

  • batch_size (int) – Batch size used when processing data. Default is 16.

  • model_config (dict) – Model parameters. Default is {'pretrained': True}. Refer to https://pytorch.org/vision/stable/models.html

  • transform – Preprocessing function applied to the input data. Default is None.

  • name (str) – Name of this embeddings model. Default is 'torchvision'.

>>> import nlpatl.models.embeddings as nme
>>> model = nme.TorchVision('resnet18')
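
The following is a minimal sketch, not nlpatl's implementation, of how `batch_size` and `transform` might interact. The hypothetical helper `iter_batches` applies `transform` to each sample and yields batches of at most `batch_size` samples. It uses NumPy only; in the real class `transform` is presumably a torchvision-style preprocessing callable, which is an assumption here, as is the channels-first image layout.

```python
from typing import Callable, Iterator, Optional

import numpy as np


def iter_batches(
    x: np.ndarray,
    batch_size: int = 16,
    transform: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Iterator[np.ndarray]:
    """Hypothetical helper (not part of nlpatl).

    Yields successive batches of at most `batch_size` samples. If `transform`
    is given, it is applied to each sample before the batch is stacked.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(x), batch_size):
        chunk = x[start:start + batch_size]
        if transform is not None:
            chunk = np.stack([transform(sample) for sample in chunk])
        yield chunk


def scale_to_unit_range(img: np.ndarray) -> np.ndarray:
    """Hypothetical preprocessing: map pixel values from [0, 1] to [-1, 1]."""
    return img * 2.0 - 1.0


# 40 fake RGB images of size 8x8, channels-first (an assumed layout).
images = np.random.rand(40, 3, 8, 8).astype(np.float32)

batches = list(iter_batches(images, batch_size=16, transform=scale_to_unit_range))
sizes = [len(b) for b in batches]
print(sizes)  # [16, 16, 8]
```

The last batch is simply shorter than `batch_size`; no padding is added in this sketch.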

convert(x)
Parameters

x (np.ndarray) – Raw features

Returns

Feature vectors (embeddings) of the input

Return type

np.ndarray
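
To illustrate the input/output contract of `convert` (an `np.ndarray` of raw features in, an `np.ndarray` of vectors out), here is a hypothetical stand-in, `convert_sketch`, that is not nlpatl's code. It assumes image input of shape (N, C, H, W) and replaces the torchvision forward pass with global average pooling over the spatial axes, so each of the N inputs maps to one row of the output.

```python
import numpy as np


def convert_sketch(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for `convert`: one vector per input image.

    Assumes channels-first input of shape (N, C, H, W). A real torchvision
    backbone would run a forward pass here; this sketch substitutes global
    average pooling over H and W, which yields an (N, C) array.
    """
    if x.ndim != 4:
        raise ValueError(f"expected (N, C, H, W) input, got shape {x.shape}")
    return x.mean(axis=(2, 3))


images = np.random.rand(5, 3, 32, 32)
vectors = convert_sketch(images)
print(vectors.shape)  # (5, 3)
```

Row i of the result corresponds to input i, which is the property downstream code typically relies on when pairing embeddings with their source samples.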