Source code for nlpatl.sampling.certainty.most_confidence

from typing import Tuple
import numpy as np

from nlpatl.sampling import Sampling


class MostConfidenceSampling(Sampling):
    """
    Sampling data points if the confidence is higher than threshold. Refer to
    https://markcartwright.com/files/wang2019active.pdf

    :param threshold: Minimum probability of model prediction. Default
        value is 0.85
    :type threshold: float
    :param name: Name of this sampling
    :type name: str
    """

    def __init__(
        self, threshold: float = 0.85, name: str = "most_confidence_sampling"
    ):
        super().__init__(name=name)
        self.threshold = threshold

    def sample(
        self, data: np.ndarray, num_sample: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Randomly pick data points whose highest score exceeds the threshold.

        :param data: Scores to sample from. The maximum is taken along
            axis 1 (assumed to be one row of class probabilities per data
            point -- confirm against callers).
        :type data: np.ndarray
        :param num_sample: Maximum number of data points to return.
        :type num_sample: int

        :return: Tuple of (selected row indices, their maximum scores). Both
            arrays are empty when no row exceeds the threshold. At most
            ``num_sample`` distinct indices are returned; fewer when fewer
            rows qualify.
        :rtype: Tuple[np.ndarray, np.ndarray]
        """
        # Calculate most confidence
        most_confidences = np.max(data, axis=1)
        candidates = np.argwhere(most_confidences > self.threshold).flatten()

        # It is possible that no result. Keep the index array integer-typed
        # so it can still be used to index other arrays.
        if len(candidates) == 0:
            return np.empty(0, dtype=int), np.empty(0)

        # Cap by the number of qualifying rows (never more than len(data))
        # and draw without replacement so no data point is returned twice.
        num_node = min(num_sample, len(candidates))
        indices = np.random.choice(candidates, num_node, replace=False)
        return indices, most_confidences[indices]