This notebook is part of a series of notebooks designed to guide you in using watsonx.ai Large Language Models (LLMs) for text classification.
The scenario presented in this notebook assumes you have a few hundred or a few thousand labeled text elements in a multi-class setup, that is, each text element is labeled with exactly one class from a given taxonomy.
Some familiarity with Python is helpful. This notebook runs with Python 3.10.
Full fine-tuning of LLMs requires significant resources. To this end, watsonx provides Prompt Tuning (PT), a parameter-efficient tuning method where the pre-trained model parameters are frozen and a soft prompt is trained on the labeled data provided. Compared to prompt engineering, PT allows the LLM to learn from many more examples, avoids the prompt length limitation, and aims at a deeper adaptation to the given data. For more details, please refer to our documentation.
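As a rough back-of-the-envelope illustration of how lightweight PT is, the sketch below compares the size of the trained soft prompt to the frozen base model. The soft prompt length of 100 tokens matches the value used later in this notebook; the embedding dimension (2048) and total parameter count (~3B) of google/flan-t5-xl are approximate assumptions, not values taken from this notebook.
# Rough illustration only (not part of the tuning flow): size of the trainable
# soft prompt vs. the frozen base model.
num_virtual_tokens = 100       # soft prompt length used later in this notebook
embedding_dim = 2048           # assumed hidden size of google/flan-t5-xl
base_model_params = 3e9        # assumed total parameter count of google/flan-t5-xl
soft_prompt_params = num_virtual_tokens * embedding_dim
print(f"Trainable soft prompt parameters: {soft_prompt_params:,}")
print(f"Fraction of the frozen base model: {soft_prompt_params / base_model_params:.6%}")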
This notebook contains the following parts:
Before you use the sample code in this notebook, you must perform the following setup tasks:
!pip install ibm-watsonx-ai -U -q
!pip install scikit-learn -q
!pip install matplotlib -q
!pip install wget -q
!pip install graphviz -q
!pip install tqdm -q
import os
import getpass
This cell defines the credentials required to work with watsonx API for Foundation Model inferencing.
Action: Provide the IBM Cloud user API URL and API key. For details, see the documentation.
try:
api_url = os.environ["API_URL"]
except KeyError:
api_url = getpass.getpass("Please enter your watsonx.ai URL domain (hit enter): ")
Please enter your watsonx.ai URL domain (hit enter): ········
try:
api_key = os.environ["API_KEY"]
except KeyError:
api_key = getpass.getpass("Please enter your watsonx.ai API key (hit enter): ")
Please enter your watsonx.ai API key (hit enter): ········
The Foundation Model requires a project id that provides the context for the call. We will try to obtain the id from the project in which this notebook runs. Otherwise, please provide the project id. To find your project id, select the project from the project list, and then take the project id from Manage -> General -> Details.
try:
project_id = os.environ["PROJECT_ID"]
except KeyError:
project_id = getpass.getpass("Please enter your watsonx.ai Project ID (hit enter): ")
Create an instance of APIClient with authentication details, and set the project id.
credentials={
"url": api_url,
"apikey": api_key
}
from ibm_watsonx_ai import APIClient
client = APIClient(credentials)
client.set.default_project(project_id)
'SUCCESS'
The recommended data size is between a few hundred and a few thousand examples. If your data is much larger (e.g., 100k examples), consider randomly sampling from it to reduce runtime (see the sketch below).
For more details read about data formats in our documentation.
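If you do need to down-sample a large dataset, a minimal sketch could look as follows; the variable name df_large and the target size of 2000 are hypothetical placeholders, to be adjusted to your data.
import pandas as pd
# Minimal sketch: randomly sample up to target_size rows from a (hypothetical) large
# DataFrame to keep tuning runtime reasonable.
def downsample(df_large: pd.DataFrame, target_size: int = 2000, seed: int = 42) -> pd.DataFrame:
    if len(df_large) <= target_size:
        return df_large
    return df_large.sample(n=target_size, random_state=seed).reset_index(drop=True)
# Example usage: df = downsample(df, target_size=2000)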
Recommendations:
- Avoid class names that are too similar to one another, e.g., bank account and my bank account, as class names which are too similar might confuse the model.

To illustrate the use of PT on multi-class data, we will use a subset of the CFPB (Consumer Financial Protection Bureau) dataset (for more details about the dataset click here). This dataset contains millions of real complaints, originally associated with 9 different products (topics). The samples we will use for training and testing in this notebook are downloaded from our git repository and cover a slightly different list of 5 products. We will use separate samples of 1k examples for training and testing.
Alternatively, you can run this notebook with your own data. For this purpose, set the variable use_your_own_data to True. In the subsequent cell, set the path to your data in csv format, the text and label column names, the train asset filename, and the path to the test file in csv format (if it exists); see the illustrative sketch below for the expected structure.
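For reference, here is a purely illustrative sketch of the expected structure: a csv file with one text column and one label column, and exactly one label per row. The file name my_data.csv and the column names text and label are hypothetical placeholders.
import pandas as pd
# Hypothetical example of a csv file in the expected format.
example = pd.DataFrame({
    "text": [
        "I was charged twice for the same transaction on my card.",
        "The collection agency keeps calling about a debt I already paid.",
    ],
    "label": ["credit card", "debt collection"],
})
example.to_csv("my_data.csv", index=False)
# You would then set: filepath = "my_data.csv", text_column = "text", label_column = "label"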
use_your_own_data = False
import wget
# Define download data function
def download_cfpb(filepath):
url = f"https://raw.github.com/IBM/watson-machine-learning-samples/master/cloud/data/cfpb_complaints/{filepath}"
wget.download(url)
if not use_your_own_data:
print("downloading cfpb data")
filepath = "cfpb_compliants.csv"
download_cfpb(filepath)
text_column = "narrative"
label_column = "product"
asset_train_filename = "CFPB_train.json"
else:
print("defining custom data parameters")
filepath = "" # the path to your data
text_column = "" # the column name for the input texts
label_column = "" # the column name for gold labels
asset_train_filename = "my_train.json" # the name of the file that is created for the train data asset (in json format)
test_filepath = None # optionally, set a path to your test data for evaluation
downloading cfpb data
We preprocess the data. For CFPB, the original labels contain underscores that we convert to spaces (e.g., credit_reporting -> credit reporting). If you use your own data, you can update this function with additional preprocessing steps for your dataset.
import pandas as pd
# preprocessing cfpb data - update this function with relevant preprocessing of your data
def preprocess(df, text_column, label_column):
print("preprocessing data")
df = df.rename(columns={text_column: "input", label_column: "output"})
df = df[['input', 'output']]
df.loc[:, "output"] = df["output"].apply(lambda l: l.replace("_", " "))
return df
df = pd.read_csv(filepath)
df = preprocess(df, text_column, label_column)
preprocessing data
We create a data asset for training.
import json
assets = client.data_assets.get_details()['resources']
assets = [a for a in assets if a['metadata']['name'] == asset_train_filename]
def get_demo_train_data(df):
return df.loc[:999, :]
if len(assets) == 0:
print("\ntrain asset does not exist")
if not use_your_own_data:
train_df = get_demo_train_data(df)
else:
train_df = df
data = train_df.to_dict("records")
with open(asset_train_filename, "w") as f:
json.dump(data, f)
train_asset = client.data_assets.create(name=asset_train_filename, file_path=asset_train_filename)
else:
print(f"\ntrain asset {asset_train_filename} exists")
train_asset = assets[0]
asset_uid = client.data_assets.get_id(train_asset)
content = client.data_assets.get_content(asset_uid).decode("utf-8")
train_df = pd.read_json(content, orient="records")
train asset does not exist Creating data asset... SUCCESS
We sort the class names lexicographically (to maintain a consistent order), and check that there are enough examples for each class.
from collections import Counter
class_names = sorted(set(train_df["output"]))
sorted_classes_by_freq = Counter(train_df["output"]).most_common()
print(f"Frequency of classes in the train set: {sorted_classes_by_freq}")
classes_with_less_than_twenty_examples = [c for c in sorted_classes_by_freq if c[1] < 20]
if len(classes_with_less_than_twenty_examples) > 0:
print("Note, the following classes have less than 20 examples in the train set. This may harm their performance:")
print(classes_with_less_than_twenty_examples)
Frequency of classes in the train set: [('debt collection', 236), ('credit reporting', 233), ('mortgages and loans', 213), ('credit card', 190), ('retail banking', 128)]
We define a DataConnection instance for the training data. This instance will be used as a reference when running PT later.
train_asset_id = client.data_assets.get_id(train_asset)
from ibm_watsonx_ai.helpers import DataConnection
data_conn = DataConnection(data_asset_id=train_asset_id)
from ibm_watsonx_ai.experiment import TuneExperiment
experiment = TuneExperiment(credentials, project_id=project_id)
We initialize the verbalizer used in training. We set two verbalizers: one whose instruction includes the class names, and one whose instruction does not.
As a subsequent step we will determine which verbalizer to use on this data.
Note: Based on our experiments, the importance of the actual phrasing of the verbalizer decreases when you have a reasonable amount of labeled data (at least 10-20 instances for each class or 300-400 overall). If the size of your data is smaller than that, you may want to consider testing other verbalizers that reflect your domain or type of text. E.g., if your use-case involves classification of legal clauses, the instruction could be classify the legal clause into one of the following categories { "class_1", "class_2", ...}, the input prefix could be legal clause:, and the output prefix could be category:.
input_prefix_val = "Input:"
output_prefix_val = "Output:"
instruction_with_class_names_val = 'classify { '+ ", ".join(['"' + c + '"' for c in class_names]) + ' }'
instruction_without_class_names_val = 'classify the following text:'
def get_class_names_verbalizer(class_names):
instruction = instruction_with_class_names_val
input_prefix = input_prefix_val
output_prefix = output_prefix_val
verbalizer = instruction + ' ' + input_prefix + ' {{input}} ' + output_prefix
return verbalizer
class_names_verbalizer = get_class_names_verbalizer(class_names)
print(f"With class names\n{class_names_verbalizer}")
With class names classify { "credit card", "credit reporting", "debt collection", "mortgages and loans", "retail banking" } Input: {{input}} Output:
def get_no_class_names_verbalizer():
instruction = instruction_without_class_names_val
input_prefix = input_prefix_val
output_prefix = output_prefix_val
verbalizer = instruction + ' ' + input_prefix + ' {{input}} ' + output_prefix
return verbalizer
no_class_names_verbalizer = get_no_class_names_verbalizer()
print(f"Without class names\n{no_class_names_verbalizer}")
Without class names classify the following text: Input: {{input}} Output:
The default token limit for inputs (verbalizer + text) for tuning is 256 tokens. Thus, if the dataset contains many classes and/or the input texts are relatively long, the actual input might not fit entirely within this limit.
The following cell estimates the length of the verbalizer in tokens.
If the verbalizer with all class names takes up more than 256 tokens, the class names will not fit in the input and the tuning may produce poor performance. In this case, we fall back to the verbalizer without the class names.
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
generate_params_for_verbalizer = {
GenParams.MAX_NEW_TOKENS: 1,
GenParams.DECODING_METHOD: 'greedy',
GenParams.MIN_NEW_TOKENS: 1,
GenParams.TRUNCATE_INPUT_TOKENS: 4096,
GenParams.RETURN_OPTIONS: {'input_text': False, 'generated_tokens': False, 'input_tokens': True, 'token_logprobs': False, 'token_ranks': False, 'top_n_tokens': 0}
}
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.foundation_models.utils.enums import ModelTypes
base_model_for_verbalizer = ModelInference(
model_id=ModelTypes.FLAN_T5_XL,
params=generate_params_for_verbalizer,
api_client=client
)
def get_verbalizer_num_tokens(verbalizer):
batch = [verbalizer.replace("{{input}}", "")]
predictions = base_model_for_verbalizer.generate_text(prompt=batch, raw_response=True)
verbalizer_num_tokens = predictions[0]['results'][0]['input_token_count']
return verbalizer_num_tokens
token_limit = 256
verbalizer_token_count = get_verbalizer_num_tokens(class_names_verbalizer)
print(f"Estimated verbalizer length: {verbalizer_token_count}")
if verbalizer_token_count >= token_limit:
verbalizer_token_count = get_verbalizer_num_tokens(no_class_names_verbalizer)
verbalizer = no_class_names_verbalizer
print("The verbalizer itself might be longer than the permitted token limit. We fallback to the generic verbalizer.")
class_names_in_train_verbalizer = False
else:
verbalizer = class_names_verbalizer
print("Using verbalizer with class names")
class_names_in_train_verbalizer = True
Estimated verbalizer length: 40 Using verbalizer with class names
We then estimate the average length of texts to assess the difference between their length and the space they have in the prompt. If this difference is large - that is, the texts will be heavily trimmed - you may want to consider using the verbalizer without the class names to allow more space for the texts.
Note: texts which will not fit the token limit are trimmed automatically, keeping the prefix. You may want to consider trimming the texts in advance to keep the most important parts of your data, e.g., if in your data the end of the text is more important than the opening.
# estimating text token length
texts = train_df["input"].tolist()
text_token_count = [len(t.split()) * 1.5 for t in texts]  # rough heuristic: ~1.5 tokens per whitespace-separated word
avg_text_length = sum(text_token_count)/len(text_token_count)
print(f"Estimated number tokens in verbalizer remaining for texts: {token_limit-verbalizer_token_count}")
print(f"Estimated average length of texts in the train set: {'{:.2f}'.format(avg_text_length)}")
if token_limit - verbalizer_token_count < avg_text_length:
print(f"If the number of tokens remaining for texts is much smaller than their actual length, this might harm performance.")
Estimated number tokens in verbalizer remaining for texts: 216 Estimated average length of texts in the train set: 151.98
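As mentioned in the note above, you may prefer to trim long texts yourself so that the most informative part is kept. Below is a minimal sketch for keeping only the tail of each text; the 150-word cut-off is an arbitrary assumption, and words are only a rough proxy for tokens.
# Sketch: keep roughly the last 150 words of each text, for data where the end of
# the text carries the most important information.
def keep_tail(text: str, max_words: int = 150) -> str:
    words = text.split()
    return " ".join(words[-max_words:])
# Example usage (uncomment to apply to the train set):
# train_df["input"] = train_df["input"].apply(keep_tail)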
The process of PT trains a soft prompt - a list of 100 token embeddings - instead of the entire model. To achieve satisfactory results with the model tuned in this notebook (google/flan-t5-xl), it is recommended to initialize the soft prompt tokens with a short classification instruction that lists the class names, ordered by their frequency in the train set (from most to least frequent). This initialization presumably allows a more semantically meaningful starting point for the soft prompt embeddings. We set the initialization text in the soft_prompt_initialization_text parameter.
class_names_order_by_prior = [c[0] for c in sorted_classes_by_freq]
soft_prompt_initialization_text = "Classify the text into one of the following categories: " + ", ".join(class_names_order_by_prior)
print(f"The soft prompt initialization text: \n{soft_prompt_initialization_text}\n")
The soft prompt initialization text: Classify the text into one of the following categories: debt collection, credit reporting, mortgages and loans, credit card, retail banking
We define the prompt_tuner parameters, including the verbalizer and the initialization text. Tuning is currently available for google/flan-t5-xl.
Note: We set the number of epochs to 50 (the num_epochs parameter), as this usually provided the best results in our experiments. You may want to consider decreasing the number of epochs to save on run-time, especially if your dataset contains more than 1k elements, at the cost of a potential decrease in performance.
from ibm_watsonx_ai.foundation_models.utils.enums import PromptTuningInitMethods
prompt_tuner = experiment.prompt_tuner(name="sample SDK run auto_update True on CFPB",
task_id=experiment.Tasks.CLASSIFICATION,
base_model=ModelTypes.FLAN_T5_XL,
accumulate_steps=16,
batch_size=16,
learning_rate=0.3,
max_input_tokens=token_limit,
max_output_tokens=20,
num_epochs=50,
tuning_type=experiment.PromptTuningTypes.PT,
verbalizer=verbalizer,
auto_update_model=True,
init_method=PromptTuningInitMethods.TEXT,
init_text=soft_prompt_initialization_text
)
prompt_tuner.get_params()
{'base_model': {'model_id': 'google/flan-t5-xl'}, 'accumulate_steps': 16, 'batch_size': 16, 'init_method': 'text', 'init_text': 'Classify the text into one of the following categories: debt collection, credit reporting, mortgages and loans, credit card, retail banking', 'learning_rate': 0.3, 'max_input_tokens': 256, 'max_output_tokens': 20, 'num_epochs': 50, 'task_id': 'classification', 'tuning_type': 'prompt_tuning', 'verbalizer': 'classify { "credit card", "credit reporting", "debt collection", "mortgages and loans", "retail banking" } Input: {{input}} Output:', 'name': 'sample SDK run auto_update True on CFPB', 'description': 'Prompt tuning with SDK', 'auto_update_model': True, 'group_by_name': False}
We run the prompt tuning process on top of the training data referenced by the DataConnection instance (tuning may take some time). When the background_mode parameter is set to True, the process runs in the background.
tuning_details = prompt_tuner.run(training_data_references=[data_conn], background_mode=True)
PT over 1k examples of the CFPB data takes about 50 minutes. If background_mode is set to True, we use the following utility function to probe the status of training. Run this cell, and it will finish when training is completed.
print(tuning_details)
{'metadata': {'created_at': '2024-01-30T17:16:18.891Z', 'description': 'Prompt tuning with SDK', 'id': 'c9eccc2d-b354-463b-a607-258a9b7eff3c', 'modified_at': '2024-01-30T17:16:18.891Z', 'name': 'sample SDK run auto_update True on CFPB', 'project_id': '7d0667b5-00b6-4f34-b002-13a1f5db1676', 'tags': ['prompt_tuning', 'wx_prompt_tune.07b351a6-7899-4aae-8ad6-65fd68706591']}, 'entity': {'auto_update_model': True, 'description': 'Prompt tuning with SDK', 'name': 'sample SDK run auto_update True on CFPB', 'project_id': '7d0667b5-00b6-4f34-b002-13a1f5db1676', 'prompt_tuning': {'accumulate_steps': 16, 'base_model': {'model_id': 'google/flan-t5-xl'}, 'batch_size': 16, 'init_method': 'text', 'init_text': 'Classify the text into one of the following categories: debt collection, credit reporting, mortgages and loans, credit card, retail banking', 'learning_rate': 0.3, 'max_input_tokens': 256, 'max_output_tokens': 20, 'num_epochs': 50, 'num_virtual_tokens': 100, 'task_id': 'classification', 'tuning_type': 'prompt_tuning', 'verbalizer': 'classify { "credit card", "credit reporting", "debt collection", "mortgages and loans", "retail banking" } Input: {{input}} Output:'}, 'results_reference': {'connection': {}, 'location': {'assets_path': 'default_tuning_output/c9eccc2d-b354-463b-a607-258a9b7eff3c/assets', 'path': 'default_tuning_output', 'training': 'default_tuning_output/c9eccc2d-b354-463b-a607-258a9b7eff3c', 'training_status': 'default_tuning_output/c9eccc2d-b354-463b-a607-258a9b7eff3c/training-status.json'}, 'type': 'container'}, 'status': {'state': 'pending'}, 'tags': ['prompt_tuning', 'wx_prompt_tune.07b351a6-7899-4aae-8ad6-65fd68706591'], 'training_data_references': [{'connection': {}, 'location': {'href': '/v2/assets/bb247953-917f-4bc6-82e8-4f61ea9e53e7', 'id': 'bb247953-917f-4bc6-82e8-4f61ea9e53e7'}, 'type': 'data_asset'}]}}
import time
model_id = tuning_details['metadata']['id']
def wait_till_model_is_ready(tuning_details):
while True:
status = prompt_tuner.get_run_status()
print(f'Model {model_id} is {status}')
if status == 'completed':
break
elif status not in ['initializing', 'pending', 'running']:
raise ValueError(f'Unexpected tune status: {status}')
time.sleep(60)
wait_till_model_is_ready(tuning_details)
Model c9eccc2d-b354-463b-a607-258a9b7eff3c is pending Model c9eccc2d-b354-463b-a607-258a9b7eff3c is pending Model c9eccc2d-b354-463b-a607-258a9b7eff3c is pending Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is running Model c9eccc2d-b354-463b-a607-258a9b7eff3c is completed
prompt_tuner.summary()
Model Name | Enhancements | Base model | Auto store | Epochs | loss |
---|---|---|---|---|---|
model_c9eccc2d-b354-463b-a607-258a9b7eff3c | [prompt_tuning] | google/flan-t5-xl | True | 50 | 0.15961 |
Plot the learning curves. The leftmost curve shows the loss over the epochs. A good tuning process is one where the loss gradually decreases as training progresses.
prompt_tuner.plot_learning_curve()
model_id = prompt_tuner.get_run_details()['entity']['model_id']
print(model_id)
5b538d7e-9a80-46ab-83ba-2215401f8b2c
from datetime import datetime
meta_props = {
client.deployments.ConfigurationMetaNames.NAME: "PT DEPLOYMENT SDK - project",
client.deployments.ConfigurationMetaNames.ONLINE: {},
client.deployments.ConfigurationMetaNames.SERVING_NAME : f"pt_sdk_deployment_{datetime.utcnow().strftime('%Y_%m_%d_%H%M%S')}"
}
deployment_details = client.deployments.create(model_id, meta_props)
####################################################################################### Synchronous deployment creation for uid: '5b538d7e-9a80-46ab-83ba-2215401f8b2c' started ####################################################################################### initializing. ready ------------------------------------------------------------------------------------------------ Successfully finished deployment creation, deployment_uid='fd95312f-38db-47e0-8dea-84404dc55895' ------------------------------------------------------------------------------------------------
We store the deployment id which will be used for inference.
deployment_details
deployment_id = client.deployments.get_id(deployment_details)
print(deployment_id)
fd95312f-38db-47e0-8dea-84404dc55895
We define the generation parameters.
- We use the greedy decoding method because we want the model to return the most probable class name for the given text.
- We set TRUNCATE_INPUT_TOKENS to 3975 to trim inputs which are longer than that.

generate_params = {
GenParams.MAX_NEW_TOKENS: 20,
GenParams.DECODING_METHOD: 'greedy',
GenParams.MIN_NEW_TOKENS: 1,
GenParams.TRUNCATE_INPUT_TOKENS: 3975
}
We define an instance of ModelInference using the deployment_id and the generation parameters.
tuned_model = ModelInference(
deployment_id=deployment_id,
params=generate_params,
api_client=client
)
If you chose to run with the demo data, we take the test data from the original CFPB dataframe. Otherwise, the test data is taken from the path specified in the variable test_filepath in the cell Use your own data.
Note: We perform the same preprocessing for the test set as done for the train set. It is assumed they share identical column names.
def get_demo_test_data(df):
return df.loc[1000:1999, :]
def get_test_data(filepath):
return pd.read_csv(filepath)
if not use_your_own_data:
test_df = get_demo_test_data(df)
else:
if test_filepath is None:
raise Exception("you need set a path to a test file to continue (in cell \"Use your own data\")")
test_df = get_test_data(test_filepath)
test_df = preprocess(test_df, text_column, label_column)
We check there are enough examples for each class in the test set.
eval_classes = Counter(test_df["output"]).most_common()
classes_with_less_than_ten_examples = [c for c in eval_classes if c[1] < 10]
if len(classes_with_less_than_ten_examples) > 0:
print("Note, the following classes have less than 10 examples in the test set. It could be hard to deduce how well the model performs on them.")
print(classes_with_less_than_ten_examples)
If the verbalizer used in tuning contained the class names, there is no need to add them at inference time. On the other hand, if the verbalizer without the class names was selected for tuning, we add them at inference.
if class_names_in_train_verbalizer:
prompts_batch = test_df["input"].tolist()
else:
prompts_batch = [class_names_verbalizer.replace("{{input}}", text) for text in test_df["input"]]
labels = test_df["output"].tolist()
We infer the tuned model over the test data.
Occasionally, the output of the model may yield class names which are not part of the original class list. From our experiments, such hallucinations usually occur for ~1% of the examples on average.
For this purpose we use a utility function from the difflib module to find the most similar class to the generated text. If none is found, we use the class with the highest prior in the train set.
Note: In case inference fails, try waiting a few seconds and running the cell again. If it still fails, you may want to re-run the Deploy section and then try inference again.
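To make the mapping step in the next cell more concrete, here is a small standalone illustration of the same logic on two made-up generated outputs; it reuses class_names and class_names_order_by_prior defined above.
import difflib
# "credit card services" is a hypothetical hallucinated output that maps to its closest
# class; "xxxx" has no close match and falls back to the most frequent train class.
fallback_class = class_names_order_by_prior[0]
for generated in ["credit card services", "xxxx"]:
    mapped = (difflib.get_close_matches(generated, class_names) + [fallback_class])[0]
    print(f"{generated!r} -> {mapped!r}")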
import time
import difflib
start = time.time()
print(f"Running inference on {len(prompts_batch)} examples")
predictions = tuned_model.generate_text(prompt=prompts_batch)
predictions = [(difflib.get_close_matches(res, class_names) + [class_names_order_by_prior[0]])[0] for res in predictions]
end = time.time()
print(f"elapsed time: {end - start}")
Running inference on 1000 examples elapsed time: 59.58870577812195
We evaluate the predictions with respect to the labels. We calculate accuracy, micro-f1, macro-f1 and f1 per class.
from sklearn.metrics import accuracy_score, f1_score
def get_labels_in_binary_form(label, class_names):
return [1 * (cls == label) for cls in class_names]
def get_metrics(labels, predictions):
binary_labels = [get_labels_in_binary_form(l, class_names) for l in labels]
binary_preds = [get_labels_in_binary_form(p, class_names) for p in predictions]
f1_per_class = f1_score(binary_labels, binary_preds, average=None)
return {
'accuracy': f'{accuracy_score(binary_labels, binary_preds)*100:.2f}',
'micro_f1': f'{f1_score(binary_labels, binary_preds, average="micro")*100:.2f}',
'macro_f1': f'{f1_score(binary_labels, binary_preds, average="macro")*100:.2f}',
'f1_per_class': (pd.Series(f1_per_class, index=class_names)*100.).round(2),
}
performance = get_metrics(labels, predictions)
print(performance)
{'accuracy': '81.40', 'micro_f1': '81.40', 'macro_f1': '81.61', 'f1_per_class': credit card 78.65 credit reporting 79.62 debt collection 74.87 mortgages and loans 88.26 retail banking 86.67 dtype: float64}
This cell is optional, and is used to compare the tuned model to the same model in a zero-shot setup. On the dataset we are using in this notebook, the gain with the tuned model is around 21 macro-f1 points.
We initialize the ModelInference class with the base model, and infer over the test data.
We need to provide the base model with the zero-shot prompt for the task. We use the prompt defined by the verbalizer with class names (class_names_verbalizer).
generate_params = {
GenParams.MAX_NEW_TOKENS: 20,
GenParams.DECODING_METHOD: 'greedy',
GenParams.MIN_NEW_TOKENS: 1,
GenParams.TRUNCATE_INPUT_TOKENS: 4096,
}
base_model = ModelInference(
model_id=ModelTypes.FLAN_T5_XL,
params=generate_params,
api_client=client
)
base_prompts_batch = [class_names_verbalizer.replace("{{input}}", text) for text in test_df["input"]]
base_predictions = base_model.generate_text(prompt=base_prompts_batch)
base_predictions = [(difflib.get_close_matches(res, class_names) + [class_names_order_by_prior[0]])[0] for res in base_predictions]
performance = get_metrics(labels, base_predictions)
print(performance)
{'accuracy': '61.30', 'micro_f1': '61.30', 'macro_f1': '60.98', 'f1_per_class': credit card 72.35 credit reporting 61.64 debt collection 40.63 mortgages and loans 56.00 retail banking 74.29 dtype: float64}
It can be helpful to investigate common cases where the model confuses one class with another. Such pairs of classes could indicate that the difference between them is too subtle for the model to distinguish, or that the labeled data is noisy.
For this purpose, we compute a confusion matrix using model predictions versus the gold labels over the test set.
Note: It is assumed the set of labels in the train and test splits are identical.
from sklearn.metrics import confusion_matrix
evaluation_df = pd.DataFrame({'label': labels, 'prediction': predictions})
def compute_confusion_matrix(evaluation_df):
labels = evaluation_df.label.tolist()
labels_sorted = sorted(list(set(labels)))
print(labels_sorted)
predictions = evaluation_df.prediction.tolist()
cm = confusion_matrix(y_true=labels, y_pred=predictions, labels=labels_sorted)
df = pd.DataFrame(cm)
df.columns = labels_sorted
df.index = labels_sorted
print(df)
compute_confusion_matrix(evaluation_df)
['credit card', 'credit reporting', 'debt collection', 'mortgages and loans', 'retail banking']

gold label \ prediction | credit card | credit reporting | debt collection | mortgages and loans | retail banking |
---|---|---|---|---|---|
credit card | 140 | 11 | 3 | 6 | 20 |
credit reporting | 10 | 207 | 24 | 6 | 2 |
debt collection | 13 | 37 | 149 | 14 | 3 |
mortgages and loans | 4 | 12 | 6 | 188 | 2 |
retail banking | 9 | 4 | 0 | 0 | 130 |
We compared the performance achieved by running PT in the setup described in this notebook to zero-shot classification using Flan-t5-xl / xxl on 9 benchmark datasets (including one dataset which appears in two versions). We report macro-f1 results. As depicted in the table below, PT provides a significant improvement on almost all datasets.
Note: the PT verbalizer column indicates whether the verbalizer with or without the class names was used for PT (following the check described in Estimate how well classes and texts fit in the input). For zero-shot we use the verbalizer with class names throughout.

Dataset | Number of classes | PT verbalizer | Flan-t5-xl PT (this notebook) | Flan-t5-xl zero-shot | Flan-t5-xxl zero-shot |
---|---|---|---|---|---|
Sentiment Classification | 3 | with classes | 82.53 | 74.64 | 75.68 |
CFPB (notebook version) | 5 | with classes | 81.61 | 60.97 | 67.26 |
CFPB | 9 | with classes | 75.2 | 72.54 | 74.34 |
Email Classification | 9 | with classes | 59.44 | 9.86 | 19.7 |
20 Newsgroup | 20 | with classes | 80.73 | 62.1 | 59.44 |
News Classification | 40 | with classes | 39.74 | 27.82 | 31.83 |
Arg-Topic | 71 | without classes | 97.77 | 95.07 | 96.65 |
banking77 | 77 | without classes | 63.39 | 51.3 | 55.51 |
LEDGAR | 98 | without classes | 60.26 | 40.04 | 38.22 |
clinc150 | 150 | without classes | 74.45 | 61.78 | 70.47 |
*Mean* | | | 71.53 | 55.65 | 58.91 |
Congratulations, you have successfully trained a model using watsonx Prompt Tuning!
Check out our Online Documentation for more samples, tutorials, documentation, how-tos, and blog posts.
Author: Shai Gretz, LM Utilization Team, IBM Research.
Copyright © 2023 IBM. This notebook and its source code are released under the terms of the MIT License.