Source code for nlpatl.sampling.uncertainty.mismatch

from typing import Tuple, List, Union
from scipy.stats import entropy
import numpy as np

from nlpatl.sampling import Sampling


class MismatchSampling(Sampling):
    """
    Sampling data points according to the mismatch between two sequences.

    Positions at which the two inputs differ are collected, then up to N of
    those positions are picked randomly (without replacement).

    :param name: Name of this sampling
    :type name: str
    """

    def __init__(self, name: str = "mismatch_sampling"):
        super().__init__(name=name)

    def sample(
        self,
        data1: Union[List[str], List[int], List[float], np.ndarray],
        data2: Union[List[str], List[int], List[float], np.ndarray],
        num_sample: int,
    ) -> Tuple[np.ndarray, None]:
        """
        Randomly pick positions where ``data1`` and ``data2`` disagree.

        :param data1: First sequence of values (e.g. one set of labels).
        :param data2: Second sequence of values, compared element-wise
            against ``data1``. Must have the same length as ``data1``.
        :param num_sample: Maximum number of mismatching positions to
            return. If fewer mismatches exist, all of them are returned.

        :return: Tuple of (indices, None). ``indices`` is an integer array
            of the selected mismatch positions in random order; it is empty
            when there is no mismatch. The second element is always None.
        :raises AssertionError: If the two inputs differ in length.
        """
        assert len(data1) == len(data2), "Two list of data have different size."

        # Find mismatch. An explicit integer dtype is required: an empty
        # Python list would otherwise become a float64 array, and the
        # returned "indices" could not be used to index the caller's data.
        mismatch_indices = np.array(
            [i for i, (d1, d2) in enumerate(zip(data1, data2)) if d1 != d2],
            dtype=int,
        )

        # Never request more samples than there are mismatches, since
        # sampling without replacement would fail otherwise.
        num_node = min(num_sample, len(mismatch_indices))
        mismatch_indices = np.random.choice(
            mismatch_indices, num_node, replace=False
        )

        return mismatch_indices, None