# Source code for nlpatl.models.classification.sklearn_classification

"""
	scikit-learn classification wrapper
"""

from typing import List, Union
from collections import defaultdict
from sklearn.linear_model import (
	LogisticRegression
)
from sklearn.svm import (
	SVC,
	LinearSVC
)
from sklearn.ensemble import (
	RandomForestClassifier
)
import numpy as np

from nlpatl.models.classification.classification import Classification
from nlpatl.dataset import Dataset

# Registry of supported model names -> scikit-learn estimator classes.
# Insertion order is meaningful: it is the order in which the supported
# names are listed when an unknown model name is rejected.
MODEL_FOR_SKLEARN_CLASSIFICATION_MAPPING_NAMES = {
	'logistic_regression': LogisticRegression,
	'svc': SVC,
	'linear_svc': LinearSVC,
	'random_forest': RandomForestClassifier,
}


class SkLearnClassification(Classification):
	"""
	A wrapper of scikit-learn classification classes.

	:param model_name: scikit-learn classification model name. Possible
		values are `logistic_regression`, `svc`, `linear_svc` and
		`random_forest`.
	:type model_name: str
	:param model_config: Model parameters, forwarded as keyword arguments
		to the estimator constructor. `None` (the default) means no extra
		parameters. Refer to
		https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model
	:type model_config: dict
	:param name: Name of this classification
	:type name: str

	:raises ValueError: If `model_name` is not one of the supported names.

	>>> import nlpatl.models.classification as nmcla
	>>> model = nmcla.SkLearnClassification()
	"""

	def __init__(self, model_name: str = 'logistic_regression',
		model_config: Union[dict, None] = None,
		name: str = 'sklearn_classification'):

		super().__init__(name=name)

		# Build a fresh dict per instance. A `{}` default would be one shared
		# object stored on every instance created without `model_config`.
		if model_config is None:
			model_config = {}

		self.model_name = model_name
		self.model_config = model_config

		if model_name not in MODEL_FOR_SKLEARN_CLASSIFICATION_MAPPING_NAMES:
			# Render the supported names as "`a`, `b`, `c`".
			supported = '`' + '`, `'.join(
				MODEL_FOR_SKLEARN_CLASSIFICATION_MAPPING_NAMES.keys()) + '`'
			raise ValueError('`{}` does not support. Supporting {} only'.format(
				model_name, supported))

		model_class = MODEL_FOR_SKLEARN_CLASSIFICATION_MAPPING_NAMES[model_name]
		self.model = model_class(**model_config)

	@staticmethod
	def get_mapping() -> dict:
		"""
		Return the mapping of supported model names to scikit-learn
		estimator classes. The module-level dict itself is returned,
		not a copy.

		:return: Model name to estimator class mapping
		:rtype: dict
		"""

		return MODEL_FOR_SKLEARN_CLASSIFICATION_MAPPING_NAMES

	def train(self, x: np.ndarray,
		y: Union[np.ndarray, List[str], List[int], List[List[str]],
			List[List[int]]]):
		"""
		Fit the wrapped model. The label encoder is (re)built from `y` via
		`build_label_encoder`, every label is translated through
		`self.label_encoder`, and the model is fitted on the encoded labels.

		:param x: Raw features
		:type x: np.ndarray
		:param y: Label of each record in `x`
		:type y: np.ndarray or list
		"""

		self.build_label_encoder(y)
		# NOTE(review): each element of `y` is used as a lookup key, so
		# nested-list labels presumably need special handling by the label
		# encoder -- confirm against `build_label_encoder`.
		y_encoded = [self.label_encoder[lab] for lab in y]
		self.model.fit(x, y_encoded)

	def predict_proba(self, x: np.ndarray,
		predict_config: Union[dict, None] = None) -> Dataset:
		"""
		Predict class probabilities and the most probable class per record.

		:param x: Raw features
		:type x: np.ndarray
		:param predict_config: Model prediction parameters, forwarded as
			keyword arguments to the model's `predict_proba`. `None` (the
			default) means no extra parameters. Refer to
			https://scikit-learn.org/stable/modules/classes.html#module-sklearn.linear_model
		:type predict_config: dict

		:return: Features (`features`), probabilities (`values`) and the
			index of the highest probability of each record (`groups`)
		:rtype: :class:`nlpatl.dataset.Dataset`
		"""

		# Avoid a shared mutable default; `None` means "no extra arguments".
		if predict_config is None:
			predict_config = {}

		# NOTE(review): per the scikit-learn docs, `SVC` only offers
		# `predict_proba` when built with `probability=True` and `LinearSVC`
		# has none, so `svc`/`linear_svc` may fail here -- confirm how
		# callers configure those models.
		probs = self.model.predict_proba(x, **predict_config)
		preds = np.argmax(probs, axis=1)

		return Dataset(
			features=x,
			values=probs,
			groups=preds.tolist())