Classifying text with a custom classification model
You can train your own models for text classification using strong classification algorithms from three different families:
- Classic machine learning using SVM (Support Vector Machines)
- Deep learning using CNN (Convolutional Neural Networks)
- A transformer-based algorithm that uses a pretrained transformer model, such as the Slate IBM Foundation model
The Watson Natural Language Processing library also offers an easy-to-use Ensemble classifier that combines different classification algorithms through majority voting.
The algorithms support multi-label and multi-class tasks, as well as special cases such as single-label tasks (each document belongs to exactly one class) and binary classification tasks.
Training classification models is CPU and memory intensive. Depending on the size of your training data, the environment might not be large enough to complete the training. If you run into issues with the notebook kernel during training, create a custom notebook environment with a larger amount of CPU and memory, and use that to run your notebook. Especially for transformer-based algorithms, you should use a GPU-based environment, if it is available to you. See Creating your own environment template.
Topic sections:
- Input data format for training
- Input data requirements
- Stopwords
- Training SVM algorithms
- Training the CNN algorithm
- Training the transformer algorithm by using the Slate IBM Foundation model
- Training a custom transformer model by using a model provided by Hugging Face
- Training an ensemble model
- Training best practices
- Applying the model on new data
- Choosing the right algorithm for your use case
Input data format for training
Classification blocks accept training data in CSV and JSON formats.
- The CSV format
The CSV file should contain no header. Each row in the CSV file represents an example record. Each record has one or more columns, where the first column represents the text and the subsequent columns represent the labels associated with that text.
Note:
- The SVM and CNN algorithms do not support training data where an instance has no labels. If you use the SVM algorithm, the CNN algorithm, or an Ensemble that includes one of these algorithms, each CSV row must have at least one label, that is, at least two columns.
- The BERT-based and Slate-based Transformer algorithms support training data where each instance has 0, 1 or more than one label.
Example 1,label 1
Example 2,label 1,label 2
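To check how a CSV file splits into text and labels before you train, you can parse it with Python's standard csv module. The following sketch uses a hypothetical helper, read_training_csv, that is not part of Watson NLP; it turns each row into the same text/labels records that the JSON format uses. With a standard CSV parser such as this one, text that contains a comma must be enclosed in double quotes, or the part after the comma is read as a label.

```python
import csv
import io
from typing import Dict, List


def read_training_csv(csv_text: str) -> List[Dict[str, object]]:
    """Parse headerless training CSV into text/labels records.

    Hypothetical helper for illustration; Watson NLP reads the file itself.
    The first column is the text; any remaining columns are labels.
    """
    records = []
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue  # skip blank lines
        records.append({"text": row[0], "labels": row[1:]})
    return records


sample = 'Example 1,label 1\nExample 2,label 1,label 2\n"Text, with a comma",label 3\n'
for record in read_training_csv(sample):
    print(record)
# {'text': 'Example 1', 'labels': ['label 1']}
# {'text': 'Example 2', 'labels': ['label 1', 'label 2']}
# {'text': 'Text, with a comma', 'labels': ['label 3']}
```

A row that contains only text yields an empty labels list, which the SVM and CNN algorithms do not accept.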
- The JSON format
The training data is represented as an array of JSON objects. Each JSON object represents one training instance and must have a text field and a labels field. The text field contains the training example, and the labels field stores the labels associated with the example (zero, one, or more than one label).
[
  { "text": "Example 1", "labels": ["label 1"] },
  { "text": "Example 2", "labels": ["label 1", "label 2"] },
  { "text": "Example 3", "labels": [] }
]
Note:
- "labels": [] denotes an example with no labels. The SVM and CNN algorithms do not support training data where an instance has no labels. If you use the SVM algorithm, the CNN algorithm, or an Ensemble that includes one of these algorithms, each JSON object must have at least one label.
- The BERT-based and Slate-based Transformer algorithms support training data where each instance has 0, 1, or more than one label.
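Before you train an SVM, CNN, or Ensemble model, it can help to find the JSON instances that have no labels. The following minimal sketch uses only Python's json module; the function name find_unlabeled is hypothetical and not part of Watson NLP.

```python
import json
from typing import List


def find_unlabeled(json_text: str) -> List[int]:
    """Return the indexes of training instances whose labels list is empty.

    Hypothetical helper for illustration: SVM, CNN, and Ensembles that
    include them reject such instances, so fix or drop them first.
    """
    instances = json.loads(json_text)
    missing = []
    for index, instance in enumerate(instances):
        # Each instance must have both a "text" and a "labels" field.
        if "text" not in instance or "labels" not in instance:
            raise ValueError(f"instance {index} lacks a text or labels field")
        if not instance["labels"]:
            missing.append(index)
    return missing


data = '''[
  {"text": "Example 1", "labels": ["label 1"]},
  {"text": "Example 2", "labels": ["label 1", "label 2"]},
  {"text": "Example 3", "labels": []}
]'''
print(find_unlabeled(data))  # [2]
```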
Input data requirements
For SVM and CNN algorithms:
- Minimum number of unique labels required: 2
- Minimum number of text examples required per label: 5
For the BERT-based and Slate-based Transformer algorithms:
- Minimum number of unique labels required: 1
- Minimum number of text examples required per label: 5
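You can check these minimums before you start a long training run. The sketch below is a hypothetical helper (check_requirements, not a Watson NLP API) that takes each example's label list. It assumes that an example with several labels counts toward each of its labels, which is an interpretation for this sketch, not documented behavior.

```python
from collections import Counter
from typing import List

# Minimums from the lists above: (unique labels, examples per label).
REQUIREMENTS = {
    "svm_cnn": (2, 5),
    "transformer": (1, 5),
}


def check_requirements(label_sets: List[List[str]], algorithm: str) -> List[str]:
    """Return a list of problems; an empty list means the minimums are met.

    Hypothetical helper for illustration; not part of Watson NLP.
    """
    min_labels, min_examples = REQUIREMENTS[algorithm]
    # Assumption: a multi-label example counts once toward each of its labels.
    counts = Counter(label for labels in label_sets for label in labels)
    problems = []
    if len(counts) < min_labels:
        problems.append(f"need at least {min_labels} unique labels, found {len(counts)}")
    for label, count in sorted(counts.items()):
        if count < min_examples:
            problems.append(f"label {label!r} has {count} examples, need {min_examples}")
    return problems


single_class = [["sports"]] * 5
print(check_requirements(single_class, "svm_cnn"))      # ['need at least 2 unique labels, found 1']
print(check_requirements(single_class, "transformer"))  # []
```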
Training data in CSV or JSON format is converted to a DataStream before training. Instead of training data files, you can also pass data streams directly to the training functions of classification blocks.
Stopwords
You can provide your own stopwords to be removed during preprocessing, either as a list or as a file. A stopwords file must be in a standard format: a single text file with one phrase per line.
Stopwords can be used only with the Ensemble classifier.
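If you keep stopwords in a file, you can also read the file into a list yourself, for example to inspect or extend it before training. In the sketch below, parse_stopwords and load_stopwords are hypothetical helpers; UTF-8 encoding, trimming whitespace, and skipping blank lines are assumptions of this sketch, not documented Watson NLP behavior.

```python
from pathlib import Path
from typing import List


def parse_stopwords(text: str) -> List[str]:
    """Split stopwords file content into phrases, one per line.

    Trimming whitespace and skipping blank lines are assumptions made for
    this sketch, not documented Watson NLP behavior.
    """
    return [line.strip() for line in text.splitlines() if line.strip()]


def load_stopwords(path: str) -> List[str]:
    """Read a UTF-8 stopwords file (one phrase per line) into a list."""
    return parse_stopwords(Path(path).read_text(encoding="utf-8"))


print(parse_stopwords("the\nof the\n\nand\n"))  # ['the', 'of the', 'and']
```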
Training SVM algorithms
SVM is a support vector machine classifier that can be trained on feature vectors produced by any of the embedding or vectorization blocks, for example, USE (Universal Sentence Encoder) embeddings and TF-IDF vectorizers. It supports multi-class and multi-label text classification and produces confidence scores by using Platt scaling.
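Platt scaling turns a raw SVM decision score f into a probability with a sigmoid, P = 1 / (1 + exp(A*f + B)), where A and B are fitted on labeled data by minimizing log loss. The following sketch shows the technique in plain Python. It is a simplified version (no target smoothing or regularization, plain batch gradient descent), not Watson NLP's internal implementation, and the function names and scores are hypothetical.

```python
import math
from typing import List, Tuple


def platt_probability(score: float, a: float, b: float) -> float:
    """Return 1 / (1 + exp(a*score + b)), computed without overflow."""
    z = a * score + b
    if z >= 0:
        return math.exp(-z) / (1.0 + math.exp(-z))
    return 1.0 / (1.0 + math.exp(z))


def fit_platt(scores: List[float], targets: List[int],
              lr: float = 0.1, steps: int = 5000) -> Tuple[float, float]:
    """Fit a and b by batch gradient descent on log loss; targets are 0 or 1."""
    a, b = 0.0, 0.0
    n = len(scores)
    for _ in range(steps):
        grad_a = grad_b = 0.0
        for score, target in zip(scores, targets):
            p = platt_probability(score, a, b)
            # For p = 1 / (1 + exp(z)), the log-loss gradient w.r.t. z is (target - p).
            grad_a += (target - p) * score
            grad_b += target - p
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return a, b


# Hypothetical held-out SVM scores with their true 0/1 labels.
scores = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
targets = [0, 0, 1, 0, 1, 1]
a, b = fit_platt(scores, targets)
for s in (-2.0, 0.0, 2.0):
    print(s, round(platt_probability(s, a, b), 3))
```

When higher scores go with the positive class, the fitted A is negative, so larger scores map to probabilities closer to 1.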
For all options that are available for configuring SVM training, enter:
help(watson_nlp.blocks.classification.svm.SVM.train)
To train SVM algorithms:
- Begin with these preprocessing steps:
import watson_nlp
from watson_core.data_model.streams.resolver import DataStreamResolver
from watson_nlp.blocks.classification.svm import SVM
training_data_file = "<ADD TRAINING DATA FILE PATH>"
# Create datastream from training data
data_stream_resolver = DataStreamResolver(target_stream_type=list, expected_keys={'text': str, 'labels': list})
training_data = data_stream_resolver.as_data_stream(training_data_file)
# Load a Syntax model
syntax_model = watson_nlp.load('syntax_izumo_en_stock')
# Create Syntax stream
text_stream, labels_stream = training_data[0], training_data[1]
syntax_stream = syntax_model.stream(text_stream)
- Train the classification model using USE embeddings. See Pretrained USE embeddings that are included in Cloud Pak for Data for a list of available, pretrained blocks.
# Load the USE embedding model
use_embedding_model = watson_nlp.load('embedding_use_en_stock')
use_train_stream = use_embedding_model.stream(syntax_stream, doc_embed_style='raw_text')
# NOTE: doc_embed_style can also be set to 'avg_sent'. For more information, see the documentation for Embeddings
# or the USE run function API docs.
use_svm_train_stream = watson_nlp.data_model.DataStream.zip(use_train_stream, labels_stream)
# Train SVM using the Universal Sentence Encoder (USE) training stream
classification_model = SVM.train(use_svm_train_stream)
Training the CNN algorithm
CNN is a simple convolutional network architecture built for multi-class and multi-label text classification on short texts. It uses GloVe embeddings, which encode word-level semantics into a vector space. The GloVe embeddings for each language are trained on the Wikipedia corpus in that language. For information about using GloVe embeddings, see GloVe Embeddings.
For all the options that are available for configuring CNN training, enter:
help(watson_nlp.blocks.classification.cnn.CNN.train)
To train CNN algorithms:
import watson_nlp
from watson_core.data_model.streams.resolver import DataStreamResolver
from watson_nlp.blocks.classification.cnn import CNN
training_data_file = "<ADD TRAINING DATA FILE PATH>"
# Create datastream from training data
data_stream_resolver = DataStreamResolver(target_stream_type=list, expected_keys={'text': str, 'labels': list})
training_data = data_stream_resolver.as_data_stream(training_data_file)
# Load a Syntax model
syntax_model = watson_nlp.load('syntax_izumo_en_stock')
# Create Syntax stream
text_stream, labels_stream = training_data[0], training_data[1]
syntax_stream = syntax_model.stream(text_stream)
# Download GloVe embeddings
glove_embedding_model = watson_nlp.load('embedding_glove_en_stock')
# Train CNN
classification_model = CNN.train(watson_nlp.data_model.DataStream.zip(syntax_stream, labels_stream), embedding=glove_embedding_model.embedding)
Training the transformer algorithm by using the Slate IBM Foundation model
The transformer algorithm using a pretrained Slate IBM Foundation model can be used for multi-class and multi-label text classification on short texts.
For all the options available for configuring Transformer training, enter:
help(watson_nlp.blocks.classification.transformer.Transformer.train)
For a list of available Slate models, see this table:
Model | Description |
---|---|
pretrained-model_slate.153m.distilled_many_transformer_multilingual_uncased | Generic, multi-purpose model |
pretrained-model_slate.125m.finance_many_transformer_en_cased | Model pretrained on finance content |
pretrained-model_slate.110m.cybersecurity_many_transformer_en_uncased | Model pretrained on cybersecurity content |
pretrained-model_slate.125m.biomedical_many_transformer_en_cased | Model pretrained on biomedical content |
To train Transformer algorithms:
import watson_nlp
from watson_nlp.blocks.classification.transformer import Transformer
from watson_core.data_model.streams.resolver import DataStreamResolver
training_data_file = "train_data.json"
# create datastream from training data
data_stream_resolver = DataStreamResolver(target_stream_type=list, expected_keys={'text': str, 'labels': list})
train_stream = data_stream_resolver.as_data_stream(training_data_file)
# Load pre-trained Slate model
pretrained_model_resource = watson_nlp.load('pretrained-model_slate.153m.distilled_many_transformer_multilingual_uncased')
# Train model
classification_model = Transformer.train(train_stream, pretrained_model_resource)
Training a custom transformer model by using a model provided by Hugging Face
You can train your custom transformer-based model by using a pretrained model from Hugging Face.
To use a Hugging Face model, specify the model name as the pretrained_model_resource parameter in the train method of watson_nlp.blocks.classification.transformer.Transformer. Go to https://huggingface.co/models to copy the model name.
To get a list of all the options available for configuring transformer training, enter:
help(watson_nlp.blocks.classification.transformer.Transformer.train)
The following code example shows how to train a transformer model by using a Hugging Face model:
import watson_nlp
from watson_nlp.blocks.classification.transformer import Transformer
from watson_core.data_model.streams.resolver import DataStreamResolver
training_data_file = "train_data.json"
# create datastream from training data
data_stream_resolver = DataStreamResolver(target_stream_type=list, expected_keys={'text': str, 'labels': list})
train_stream = data_stream_resolver.as_data_stream(training_data_file)
# Specify the name of the Hugging Face model
huggingface_model_name = 'xlm-roberta-base'
# Train model
classification_model = Transformer.train(train_stream, pretrained_model_resource=huggingface_model_name)
Training an ensemble model
The Ensemble model is a weighted ensemble of these three algorithms: CNN, SVM with TF-IDF and SVM with USE. It computes the weighted mean of a set of classification predictions using confidence scores. The ensemble model is very easy to use.
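The weighted mean of confidence scores can be written as score(c) = sum_i(w_i * p_i(c)) / sum_i(w_i), where p_i(c) is base classifier i's confidence for class c and w_i is its weight. The sketch below computes it in plain Python. The weights and scores are made-up values, weighted_ensemble is a hypothetical helper, and treating a class that a classifier did not score as confidence 0.0 is an assumption; this is not how Watson NLP is called.

```python
from typing import Dict, List


def weighted_ensemble(predictions: List[Dict[str, float]],
                      weights: List[float]) -> Dict[str, float]:
    """Weighted mean of per-class confidence scores from several classifiers.

    Hypothetical helper for illustration. A class that a classifier did not
    score is treated as confidence 0.0 (an assumption of this sketch).
    """
    if len(predictions) != len(weights):
        raise ValueError("need exactly one weight per classifier")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive number")
    classes = sorted({c for prediction in predictions for c in prediction})
    return {
        c: sum(w * p.get(c, 0.0) for p, w in zip(predictions, weights)) / total_weight
        for c in classes
    }


# Made-up confidence scores for one input text from the three base classifiers.
cnn = {"sports": 0.9, "politics": 0.1}
tfidf_svm = {"sports": 0.6, "politics": 0.4}
use_svm = {"sports": 0.3, "politics": 0.7}
combined = weighted_ensemble([cnn, tfidf_svm, use_svm], weights=[0.5, 0.25, 0.25])
print(max(combined, key=combined.get), combined)
```

Here sports ends up at about 0.675 and politics at about 0.325: the CNN's confident vote for sports outweighs the USE-SVM's preference for politics because the CNN carries the largest weight in this example.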
The GenericEnsemble classifier gives you more flexibility to choose which of the three base classifiers to combine: TFIDF-SVM, USE-SVM, and CNN. For texts ranging from 50 to 1000 characters, combining the TFIDF-SVM and USE-SVM classifiers often yields a good balance of quality and performance. On some medium or long documents (500-1000+ characters), adding the CNN to the Ensemble could help increase quality, but it usually comes with a significant runtime performance impact (lower throughput and increased model loading time).
For all of the options available for configuring Ensemble training, enter:
help(watson_nlp.workflows.classification.GenericEnsemble)
To train Ensemble algorithms:
import watson_nlp
from watson_core.data_model.streams.resolver import DataStreamResolver
from watson_nlp.workflows.classification import GenericEnsemble
from watson_nlp.workflows.classification.base_classifier import GloveCNN
from watson_nlp.workflows.classification.base_classifier import TFidfSvm
from watson_nlp.workflows.classification.base_classifier import UseSvm
training_data_file = "<ADD TRAINING DATA FILE PATH>"
# Create datastream from training data
data_stream_resolver = DataStreamResolver(target_stream_type=list, expected_keys={'text': str, 'labels': list})
training_data = data_stream_resolver.as_data_stream(training_data_file)
# Syntax Model
syntax_model = watson_nlp.load('syntax_izumo_en_stock')
# USE Embedding Model
use_model = watson_nlp.load('embedding_use_en_stock')
# GloVe Embedding model
glove_model = watson_nlp.load('embedding_glove_en_stock')
ensemble_model = GenericEnsemble.train(
    training_data,
    syntax_model,
    base_classifiers_params=[
        TFidfSvm.TrainParams(syntax_model=syntax_model),
        GloveCNN.TrainParams(syntax_model=syntax_model, glove_embedding_model=glove_model, cnn_epochs=5),
        UseSvm.TrainParams(syntax_model=syntax_model, use_embedding_model=use_model, doc_embed_style='raw_text'),
    ],
    use_ewl=True,
)
Pretrained stopword models available out-of-the-box
The text model for identifying stopwords is used in training the document classification ensemble model.
The following table lists the pretrained stopword models and the language codes that are supported (xx stands for the language code). For a list of the language codes and the corresponding languages, see Language codes.
Resource class | Model name | Supported languages |
---|---|---|
text | text_stopwords_classification_ensemble_xx_stock | ar, de, es, en, fr, it, ja, ko |
Training best practices
The following guidelines on data quantity and quality help ensure that classification model training completes in a reasonable amount of time and that the resulting model meets various performance criteria. None of them are hard restrictions. However, the further you deviate from these guidelines, the greater the chance that the model fails to train or that the trained model is not satisfactory.
- Data quantity
  - Classification models have been tested with up to approximately 1200 classes.
  - The best-suited text size for training and test data is around 3000 code points. Larger texts can also be processed, but runtime performance might be slower.
  - Training time increases with the number of examples and the number of labels.
  - Inference time increases with the number of labels.
- Data quality
  - The size of each sample (for example, the number of phrases in each training sample) can affect quality.
  - Class separation is important. In other words, classes in the training (and test) data should be semantically distinguishable from one another to avoid misclassifications. Because the classifier algorithms in Watson Natural Language Processing rely on word embeddings, training classes that contain text examples with too much semantic overlap can make high-quality classification computationally intractable. While more sophisticated heuristics might exist for assessing the semantic similarity between classes, start with a simple "eye test" of a few examples from each class to judge whether they seem adequately separated.
  - Use balanced data for training. Ideally, the training data contains roughly equal numbers of examples from each class; otherwise, the classifiers might be biased towards the classes with larger representation.
  - Avoid training data in which some classes are highly under-represented compared to other classes.
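To apply the balance guidelines, you can count examples per class and compare the largest class with the smallest. The sketch below uses a hypothetical helper, class_balance; the imbalance ratio it reports is a common rule-of-thumb measure, not a Watson NLP metric, and this document does not set a threshold for it.

```python
from collections import Counter
from typing import Dict, List


def class_balance(label_sets: List[List[str]]) -> Dict[str, object]:
    """Summarize how evenly examples are spread across classes.

    Hypothetical helper for illustration. The imbalance ratio is the size
    of the largest class divided by the size of the smallest; 1.0 means
    perfectly balanced.
    """
    counts = Counter(label for labels in label_sets for label in labels)
    if not counts:
        return {"counts": {}, "imbalance_ratio": None}
    largest, smallest = max(counts.values()), min(counts.values())
    return {"counts": dict(counts.most_common()), "imbalance_ratio": largest / smallest}


labels = [["billing"]] * 40 + [["shipping"]] * 35 + [["returns"]] * 5
print(class_balance(labels))
# {'counts': {'billing': 40, 'shipping': 35, 'returns': 5}, 'imbalance_ratio': 8.0}
```

A large ratio points at the under-represented classes (returns in this made-up data) that are worth topping up with more examples.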
Limitations and caveats:
- The CNN classification block has a predefined sequence length of 1000 code points. You can configure this limit at train time by changing the max_phrase_len parameter. There is no maximum limit for this parameter, but increasing the maximum phrase length increases CPU and memory consumption.
- SVM blocks do not have such a limit on sequence length and can be used with longer texts.
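Because these limits are stated in code points, you can measure your texts with Python's len(), which counts Unicode code points rather than bytes. The hypothetical helper below reports how many texts exceed the CNN's default max_phrase_len of 1000 code points and the recommended text size of around 3000 code points, which can help you decide whether raising max_phrase_len is worth the extra CPU and memory.

```python
from typing import Dict, List

CNN_DEFAULT_MAX_PHRASE_LEN = 1000  # default CNN sequence length, in code points
RECOMMENDED_TEXT_SIZE = 3000       # best-suited text size, in code points


def length_report(texts: List[str]) -> Dict[str, int]:
    """Count texts that exceed the code-point thresholds from this section.

    Hypothetical helper for illustration. Python's len() on a str counts
    Unicode code points, which is the unit these limits are stated in.
    """
    return {
        "total": len(texts),
        "over_cnn_default": sum(len(t) > CNN_DEFAULT_MAX_PHRASE_LEN for t in texts),
        "over_recommended": sum(len(t) > RECOMMENDED_TEXT_SIZE for t in texts),
        "longest": max((len(t) for t in texts), default=0),
    }


texts = ["short text", "\u00e9" * 1200, "x" * 3500]
print(length_report(texts))
# {'total': 3, 'over_cnn_default': 2, 'over_recommended': 1, 'longest': 3500}
```

The second text is 1200 code points long but 2400 bytes in UTF-8, because "\u00e9" (é) encodes as two bytes, so byte counts would overstate its length.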
Applying the model on new data
After you train the model on a data set, apply it to new data by using the run() method, as you would with any of the existing pretrained blocks.
Sample code
- For the Ensemble models:
# Run the Ensemble model on new text
ensemble_prediction = ensemble_classification_model.run("new input text")
- For SVM and CNN models, for example, for CNN:
# Run the Syntax model first
syntax_result = syntax_model.run("new input text")
# Run the CNN model on top of the syntax result
cnn_prediction = cnn_classification_model.run(syntax_result)
Choosing the right algorithm for your use case
You need to choose the model algorithm that best suits your use case.
When choosing between SVM, CNN, and Transformers, consider the following:
- Transformer-based Slate
  - Choose when high quality is required and higher computing resources are available.
- CNN
  - Choose when a decent amount of training data is available.
  - Choose if GloVe embeddings are available for the required language.
  - Choose if you have the option between single-label and multi-label classification.
  - CNN fine-tunes embeddings, so it could give better performance for unknown terms or newer domains.
- SVM
  - Choose if an easier and simpler model is required.
  - SVM has the fastest training and inference time.
  - Choose if your data set size is small.
If you select SVM, consider the following when choosing between the various implementations of SVM:
- SVMs train multi-label classifiers.
- The larger the number of classes, the longer the training time.
- TF-IDF:
  - Choose TF-IDF vectorization with SVM if the data set is small, that is, it has a small number of classes, a small number of examples, and shorter texts (for example, sentences containing fewer phrases).
  - TF-IDF with SVM can be faster than other algorithms in the classification block.
  - Choose TF-IDF if embeddings for the required language are not available.
- USE:
  - Choose Universal Sentence Encoder (USE) with SVM if the data set has one or more sentences in the input text.
  - USE can perform better on data sets where understanding the context of words or sentences is important.
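TF-IDF weights each term by how often it appears in a text (term frequency) and how rare it is across the training set (inverse document frequency). Because the weights come only from your own training data, it needs no pretrained embeddings, which is why it remains an option for languages without them. The sketch below implements one textbook variant (tf = count / length, idf = log(N / df), lowercase whitespace tokenization); the function name is hypothetical, and Watson NLP's TF-IDF vectorizer might tokenize and weight terms differently.

```python
import math
from collections import Counter
from typing import Dict, List


def tfidf(documents: List[str]) -> List[Dict[str, float]]:
    """Compute TF-IDF weights for lowercased, whitespace-tokenized documents.

    Uses tf = term count / document length and idf = log(N / df). This is
    one textbook variant; Watson NLP's TF-IDF block may weight differently.
    """
    tokenized = [doc.lower().split() for doc in documents]
    n_docs = len(tokenized)
    # Document frequency: in how many documents each term appears.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        counts = Counter(tokens)
        length = len(tokens) or 1  # avoid division by zero for empty documents
        vectors.append({
            term: (count / length) * math.log(n_docs / df[term])
            for term, count in counts.items()
        })
    return vectors


docs = ["refund my order", "track my order", "refund refund please"]
for vector in tfidf(docs):
    print({term: round(weight, 3) for term, weight in vector.items()})
```

Terms that occur in only one document, such as track and please, receive the highest weights, while terms shared by two of the three documents, such as my and order, are weighted down. An SVM then learns from these sparse vectors instead of from embeddings.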
The Ensemble model combines multiple individual (diverse) models to deliver superior prediction power. Consider the following key points for this model type:
- The ensemble model combines CNN, SVM with TF-IDF, and SVM with USE.
- It is the easiest model to use.
- It can give better performance than the individual algorithms.
- It works for all kinds of data sets. However, training time for large data sets (more than 20,000 examples) can be high.
- An ensemble model allows you to set weights. These weights decide how the ensemble model combines the results of individual classifiers. Currently, weights are selected heuristically and must be set by trial and error. The default weights that are provided in the function itself are a good starting point for exploration.
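Setting weights by trial and error can be automated on a held-out validation set: score each candidate weighting by how often the combined prediction matches the true label, and keep the best one. The sketch below does a simple grid search in plain Python. It assumes single-label examples and that you already have each base classifier's per-class confidence scores for the validation texts; combine and search_weights are hypothetical helpers, and how you pass the chosen weights to GenericEnsemble is not shown here.

```python
import itertools
from typing import Dict, List, Optional, Sequence, Tuple

Prediction = Dict[str, float]  # class name -> confidence score


def combine(predictions: Sequence[Prediction], weights: Sequence[float]) -> str:
    """Return the class with the highest weighted mean confidence."""
    classes = sorted({c for p in predictions for c in p})  # sorted: ties break alphabetically
    total = sum(weights)
    scores = {c: sum(w * p.get(c, 0.0) for p, w in zip(predictions, weights)) / total
              for c in classes}
    return max(classes, key=lambda c: scores[c])


def search_weights(validation: List[Tuple[List[Prediction], str]],
                   grid: Sequence[float] = (0.0, 0.25, 0.5, 0.75, 1.0),
                   ) -> Tuple[Optional[Tuple[float, ...]], float]:
    """Try every weight combination on the grid and keep the most accurate one.

    Each validation item pairs the base classifiers' predictions for one
    example with that example's true label (single-label only).
    """
    n_classifiers = len(validation[0][0])
    best_weights: Optional[Tuple[float, ...]] = None
    best_accuracy = -1.0
    for weights in itertools.product(grid, repeat=n_classifiers):
        if sum(weights) == 0:
            continue  # all-zero weights give no weighted mean
        correct = sum(combine(preds, weights) == truth for preds, truth in validation)
        accuracy = correct / len(validation)
        if accuracy > best_accuracy:  # strict: earlier grid points win ties
            best_weights, best_accuracy = weights, accuracy
    return best_weights, best_accuracy


# Toy validation data: classifier 0 is always right, classifier 1 always wrong.
validation = [
    ([{"a": 0.9, "b": 0.1}, {"a": 0.2, "b": 0.8}], "a"),
    ([{"a": 0.3, "b": 0.7}, {"a": 0.6, "b": 0.4}], "b"),
]
print(search_weights(validation))  # ((0.25, 0.0), 1.0)
```

With real data, a grid centered on the function's default weights keeps the search small, and accuracy on a validation set that was not used for training guards against overfitting the weights.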