from typing import Tuple, List, Union
from scipy.stats import entropy
import numpy as np
from nlpatl.sampling import Sampling
class MismatchSampling(Sampling):
    """
    Sampling data points according to the mismatch. Positions where the two
    inputs disagree are collected and up to N of them are picked randomly,
    without replacement.

    :param name: Name of this sampling
    :type name: str
    """

    def __init__(self, name: str = 'mismatch_sampling'):
        super().__init__(name=name)

    def sample(self, data1: Union[List[str], List[int], List[float], np.ndarray],
               data2: Union[List[str], List[int], List[float], np.ndarray],
               num_sample: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Randomly pick indices at which ``data1`` and ``data2`` differ.

        :param data1: First sequence of values.
        :type data1: list of str, list of int, list of float or
            :class:`numpy.ndarray`
        :param data2: Second sequence of values, compared element by element
            against ``data1``. Must have the same length as ``data1``.
        :type data2: list of str, list of int, list of float or
            :class:`numpy.ndarray`
        :param num_sample: Maximum number of indices to return. Fewer are
            returned when there are fewer mismatches.
        :type num_sample: int

        :return: Tuple of the selected mismatch indices (integer array, in
            random order) and ``None``. The second element is always
            ``None`` in this implementation.
        :rtype: Tuple of :class:`numpy.ndarray` and None
        """
        assert len(data1) == len(data2), 'Two list of data have different size.'

        # Positions where the two inputs disagree. The integer dtype is forced
        # so that an empty result is still usable as an index array; NumPy
        # would otherwise infer float64 for an empty list.
        mismatch_indices = np.array(
            [i for i, (d1, d2) in enumerate(zip(data1, data2)) if d1 != d2],
            dtype=int)

        # Never request more samples than there are mismatches: sampling
        # without replacement would raise otherwise.
        num_node = min(num_sample, len(mismatch_indices))
        mismatch_indices = np.random.choice(
            mismatch_indices, num_node, replace=False)

        return mismatch_indices, None