Source code for nlpatl.models.embeddings.sentence_transformers

from typing import List
import numpy as np

try:
    import torch
    from sentence_transformers import SentenceTransformer
except ImportError:
    # No installation required if not using this function
    pass

from nlpatl.models.embeddings import Embeddings


class SentenceTransformers(Embeddings):
    """
    A wrapper of the sentence-transformers ``SentenceTransformer`` class.

    :param model_name_or_path: sentence transformers model name. Refer to
        https://www.sbert.net/docs/pretrained_models.html
    :type model_name_or_path: str
    :param batch_size: Batch size of data processing. Default is 16
    :type batch_size: int
    :param name: Name of this embeddings
    :type name: str

    >>> import nlpatl.models.embeddings as nme
    >>> model = nme.SentenceTransformers('all-MiniLM-L6-v2')
    """

    def __init__(
        self,
        model_name_or_path: str,
        batch_size: int = 16,
        name: str = "sentence_transformers",
    ):
        super().__init__(batch_size=batch_size, name=name)

        self.model_name_or_path = model_name_or_path
        # NOTE: `SentenceTransformer` comes from an optional import at module
        # level; if that import failed, this line raises NameError.
        self.model = SentenceTransformer(model_name_or_path)
        # Switch the underlying model to evaluation mode; this wrapper is
        # only used for inference.
        self.model.eval()

    def convert(self, x: List[str]) -> np.ndarray:
        """
        Encode texts into embeddings, one batch at a time.

        :param x: Raw texts to encode.
        :type x: list of str

        :return: Array built from the per-batch outputs of
            ``self.model.encode``, in input order (assuming ``self.batch``
            yields consecutive slices of ``x`` -- it is defined outside
            this class, presumably on ``Embeddings``).
        :rtype: np.ndarray
        """
        results = []
        # `self.batch` and `self.batch_size` are not defined in this class;
        # they are presumably provided by the `Embeddings` base class.
        for batch_inputs in self.batch(x, self.batch_size):
            # Inference only: no gradients need to be tracked.
            with torch.no_grad():
                results.extend(self.model.encode(batch_inputs))

        return np.array(results)