nlpatl.models.embeddings.torchvision
- class nlpatl.models.embeddings.torchvision.TorchVision(model_name_or_path, batch_size=16, model_config={'pretrained': True}, transform=None, name='torchvision')
Bases: nlpatl.models.embeddings.embeddings.Embeddings
A wrapper around torchvision models for computing embeddings.
- Parameters
model_name_or_path (str) – torchvision model name. Possible values are resnet18, alexnet and vgg16.
batch_size (int) – Batch size used for data processing. Default is 16.
model_config (dict) – Model parameters. Refer to https://pytorch.org/vision/stable/models.html
transform – Preprocessing function.
name (str) – Name of this embeddings model.
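As a rough illustration of how the batch_size and transform parameters typically interact, the following plain-Python sketch splits inputs into chunks of at most batch_size items and applies an optional preprocessing function to each item. The helper `iter_batches` is hypothetical: it is not part of nlpatl and makes no claim about how TorchVision processes data internally.

```python
from typing import Any, Callable, Iterator, List, Optional, Sequence


def iter_batches(
    items: Sequence[Any],
    batch_size: int = 16,
    transform: Optional[Callable[[Any], Any]] = None,
) -> Iterator[List[Any]]:
    """Hypothetical helper: yield lists of at most `batch_size` items.

    If `transform` is given, it is applied to every item before the
    batch is yielded (a stand-in for a preprocessing step).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        chunk = items[start:start + batch_size]
        # Apply the optional preprocessing function item by item.
        if transform is not None:
            chunk = [transform(x) for x in chunk]
        yield list(chunk)


# 40 inputs with the default batch size of 16 give batches of 16, 16 and 8.
print([len(b) for b in iter_batches(list(range(40)))])
```

With 40 inputs and the default batch size of 16, this prints [16, 16, 8]. The final batch is simply smaller, so no input is dropped or padded.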
>>> import nlpatl.models.embeddings as nme
>>> model = nme.TorchVision('resnet18')