This notebook contains the steps and code to demonstrate support of prompt tuning in watsonx.
Some familiarity with Python is helpful. This notebook uses Python 3.10.
The goal of this notebook is to demonstrate how to prompt-tune a foundation model on a classification task, deploy the tuned model, and compare its accuracy with the base model.
Before you use the sample code in this notebook, you must have a watsonx.ai project and an IBM Cloud API key.
Install and import the datasets and dependencies.
!pip install ibm-watsonx-ai==1.0.335 | tail -n 1
!pip install scikit-learn | tail -n 1
!pip install matplotlib | tail -n 1
!pip install wget | tail -n 1
This cell defines the credentials required to work with watsonx API for Foundation Model inferencing.
Action: Provide the IBM Cloud user API key. For details, see documentation.
import getpass
credentials = {
"url": "https://us-south.ml.cloud.ibm.com",
"apikey": getpass.getpass("Please enter your WML api key (hit enter): ")
}
The Foundation Model requires a project id that provides the context for the call. We will obtain the id from the project in which this notebook runs; otherwise, please provide the project id.
import os
try:
project_id = os.environ["PROJECT_ID"]
except KeyError:
project_id = input("Please enter your project_id (hit enter): ")
Create an instance of APIClient with authentication details.
from ibm_watsonx_ai import APIClient
client = APIClient(credentials)
To interact with all resources available in Watson Machine Learning, you need to set the project_id that you will be using.
client.set.default_project(project_id)
'SUCCESS'
This example uses the Consumer Financial Protection Bureau training dataset.
filename = 'cfpb_train.json'
Download training data from git repository and create data assets.
import wget
import os
url = "https://raw.github.com/IBM/watson-machine-learning-samples/master/cloud/data/prompt_tuning/cfpb_train.json"
if not os.path.isfile(filename):
wget.download(url)
asset_details = client.data_assets.create(name=filename, file_path=filename)
Creating data asset... SUCCESS
asset_id = client.data_assets.get_id(asset_details)
Define connection information to training data.
from ibm_watsonx_ai.helpers import DataConnection
data_conn = DataConnection(data_asset_id=asset_id)
from ibm_watsonx_ai.experiment import TuneExperiment
experiment = TuneExperiment(credentials, project_id=project_id)
All available tasks are presented under the Tasks enum.
{task.name: task.value for task in experiment.Tasks}
{'QUESTION_ANSWERING': 'question_answering', 'SUMMARIZATION': 'summarization', 'RETRIEVAL_AUGMENTED_GENERATION': 'retrieval_augmented_generation', 'CLASSIFICATION': 'classification', 'GENERATION': 'generation', 'CODE_GENERATION_AND_CONVERSION': 'code', 'EXTRACTION': 'extraction'}
Define the prompt_tuner parameters.
prompt_tuner = experiment.prompt_tuner(name="sample SDK run auto_update True",
task_id=experiment.Tasks.CLASSIFICATION,
base_model='google/flan-t5-xl',
accumulate_steps=32,
batch_size=16,
learning_rate=0.2,
max_input_tokens=256,
max_output_tokens=20,
num_epochs=6,
tuning_type=experiment.PromptTuningTypes.PT,
verbalizer="Including narratives choice the best match product with the items from the list: 'credit_card', 'debt_collection', 'mortgages_and_loans', 'credit_reporting', 'retail_banking'. Input: {{input}} Output: ",
auto_update_model=True
)
You can review the previously set parameters.
prompt_tuner.get_params()
{'base_model': {'model_id': 'google/flan-t5-xl'}, 'accumulate_steps': 32, 'batch_size': 16, 'learning_rate': 0.2, 'max_input_tokens': 256, 'max_output_tokens': 20, 'num_epochs': 6, 'task_id': 'classification', 'tuning_type': 'prompt_tuning', 'verbalizer': "Including narratives choice the best match product with the items from the list: 'credit_card', 'debt_collection', 'mortgages_and_loans', 'credit_reporting', 'retail_banking'. Input: {{input}} Output: ", 'name': 'sample SDK run auto_update True', 'description': 'Prompt tuning with SDK', 'auto_update_model': True, 'group_by_name': False}
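Before starting a long-running tuning job, it can be worth sanity-checking the returned parameters. The helper below is an illustrative sketch, not part of the ibm-watsonx-ai SDK, and the value ranges it checks are assumptions for this sample rather than limits enforced by the service.

```python
# Illustrative sanity check for the dict returned by prompt_tuner.get_params().
# The ranges below are assumptions for this sample, not SDK-enforced limits.
def check_tuning_params(params):
    """Return a list of human-readable warnings for suspicious values."""
    warnings = []
    if not 0 < params.get("learning_rate", 0) <= 1:
        warnings.append("learning_rate should typically be in (0, 1]")
    if params.get("batch_size", 0) < 1:
        warnings.append("batch_size must be at least 1")
    if params.get("num_epochs", 0) < 1:
        warnings.append("num_epochs must be at least 1")
    if params.get("max_output_tokens", 0) > params.get("max_input_tokens", 0):
        # For classification the expected output is a short label,
        # so this is only a heuristic.
        warnings.append("max_output_tokens unusually large relative to max_input_tokens")
    return warnings

sample = {"learning_rate": 0.2, "batch_size": 16, "num_epochs": 6,
          "max_input_tokens": 256, "max_output_tokens": 20}
print(check_tuning_params(sample))  # → []
```

With the parameters used in this notebook the check returns an empty list.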
Run a prompt tuning process of the foundation model on the training data referenced by DataConnection (tuning may take some time). By changing the background_mode parameter to True, the prompt tuning process will run in the background.
tuning_details = prompt_tuner.run(training_data_references=[data_conn], background_mode=False)
############################################## Running 'b4300476-f8dd-4684-81ec-290c15d1ed95' ############################################## pending........ running..................................................................................... completed Training of 'b4300476-f8dd-4684-81ec-290c15d1ed95' finished successfully.
Check the status/state of the initialized prompt tuning run if it was started in background mode, or after the process finishes if background mode is off.
prompt_tuner.get_run_status()
'completed'
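If you start the run with background_mode=True, you need to poll the status yourself before using the tuned model. The polling helper below is a minimal sketch: the terminal state names and the interval/timeout values are assumptions for this sample, and `get_status` stands in for the notebook's `prompt_tuner.get_run_status`.

```python
import time

# Minimal polling sketch for a tuning run started with background_mode=True.
# Terminal states and the polling interval are assumptions for this sample.
def wait_for_completion(get_status, interval=30.0, timeout=3600.0,
                        terminal_states=("completed", "failed", "canceled")):
    """Poll get_status() until it returns a terminal state or the timeout expires."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        state = get_status()
        if state in terminal_states:
            return state
        time.sleep(interval)
    raise TimeoutError("tuning run did not finish within the timeout")

# In the notebook you would start the run in the background and then wait:
#   prompt_tuner.run(training_data_references=[data_conn], background_mode=True)
#   wait_for_completion(prompt_tuner.get_run_status)
```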
Get prompt tuning run details.
prompt_tuner.get_run_details()
{'metadata': {'created_at': '2023-12-06T14:31:27.126Z', 'description': 'Prompt tuning with SDK', 'id': 'b4300476-f8dd-4684-81ec-290c15d1ed95', 'modified_at': '2023-12-06T14:41:17.834Z', 'name': 'sample SDK run auto_update True', 'project_id': '70b1b2eb-b820-41ee-bcfb-76d8d83ee5c0', 'tags': ['prompt_tuning', 'wx_prompt_tune.2e31c747-cb30-49ce-9cff-ccad973ad9d2']}, 'entity': {'auto_update_model': True, 'description': 'Prompt tuning with SDK', 'model_id': 'a0613c34-faa5-49c0-86b5-9d3357293ce1', 'name': 'sample SDK run auto_update True', 'project_id': '70b1b2eb-b820-41ee-bcfb-76d8d83ee5c0', 'prompt_tuning': {'accumulate_steps': 32, 'base_model': {'model_id': 'google/flan-t5-xl'}, 'batch_size': 16, 'init_method': 'random', 'learning_rate': 0.2, 'max_input_tokens': 256, 'max_output_tokens': 20, 'num_epochs': 6, 'num_virtual_tokens': 100, 'task_id': 'classification', 'tuning_type': 'prompt_tuning', 'verbalizer': "Including narratives choice the best match product with the items from the list: 'credit_card', 'debt_collection', 'mortgages_and_loans', 'credit_reporting', 'retail_banking'. 
Input: {{input}} Output: "}, 'results_reference': {'connection': {}, 'location': {'path': 'default_tuning_output', 'training': 'default_tuning_output/b4300476-f8dd-4684-81ec-290c15d1ed95', 'training_status': 'default_tuning_output/b4300476-f8dd-4684-81ec-290c15d1ed95/training-status.json', 'model_request_path': 'default_tuning_output/b4300476-f8dd-4684-81ec-290c15d1ed95/assets/b4300476-f8dd-4684-81ec-290c15d1ed95/resources/wml_model/request.json', 'assets_path': 'default_tuning_output/b4300476-f8dd-4684-81ec-290c15d1ed95/assets'}, 'type': 'container'}, 'status': {'completed_at': '2023-12-06T14:41:17.672Z', 'state': 'completed'}, 'tags': ['prompt_tuning', 'wx_prompt_tune.2e31c747-cb30-49ce-9cff-ccad973ad9d2'], 'training_data_references': [{'connection': {}, 'location': {'href': '/v2/assets/606d9ab5-4672-4807-83f0-33615279392d', 'id': '606d9ab5-4672-4807-83f0-33615279392d'}, 'type': 'data_asset'}]}}
Let's summarize the run.
prompt_tuner.summary()
Model Name | Enhancements | Base model | Auto store | Epochs | loss
---|---|---|---|---|---
model_b4300476-f8dd-4684-81ec-290c15d1ed95 | [prompt_tuning] | google/flan-t5-xl | True | 6 | 0.356224
Plot learning curves.
prompt_tuner.plot_learning_curve()
You can retrieve the model_id from the tuning details.
model_id = None
if 'model_id' in tuning_details.get('entity', {}):
model_id = tuning_details['entity']['model_id']
model_id
'a0613c34-faa5-49c0-86b5-9d3357293ce1'
Create online deployment for published model.
from datetime import datetime
meta_props = {
client.deployments.ConfigurationMetaNames.NAME: "PT DEPLOYMENT SDK - project",
client.deployments.ConfigurationMetaNames.ONLINE: {},
client.deployments.ConfigurationMetaNames.SERVING_NAME : f"pt_sdk_deployment_{datetime.utcnow().strftime('%Y_%m_%d_%H%M%S')}"
}
deployment_details = client.deployments.create(model_id, meta_props)
####################################################################################### Synchronous deployment creation for uid: 'a0613c34-faa5-49c0-86b5-9d3357293ce1' started ####################################################################################### initializing ready ------------------------------------------------------------------------------------------------ Successfully finished deployment creation, deployment_uid='5217e49f-3c61-4f06-ac67-ca1b2de4bdf2' ------------------------------------------------------------------------------------------------
Additionally, you can get the deployment details by printing deployment_details.
deployment_details
You can retrieve the deployment_id from the deployment details.
deployment_id = deployment_details['metadata']['id']
deployment_id
'5217e49f-3c61-4f06-ac67-ca1b2de4bdf2'
Provide a set of model parameters for inferencing with watsonx.ai.
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
generate_params = {
GenParams.MAX_NEW_TOKENS: 20,
GenParams.STOP_SEQUENCES: ["\n"]
}
Initialize the ModelInference class.
from ibm_watsonx_ai.foundation_models import ModelInference
tuned_model = ModelInference(
deployment_id=deployment_id,
params=generate_params,
api_client=client
)
Alternatively, you can use credentials and project_id to initialize the ModelInference class.
tuned_model = ModelInference(
deployment_id=deployment_id,
params=generate_params,
credentials=credentials,
project_id=project_id
)
Get deployment model inference details.
tuned_model.get_details()
Analyze the product class for a sample prompt.
response = tuned_model.generate_text(prompt="Including narratives choice the best match product with the items from the list: 'credit_card', 'debt_collection', 'mortgages_and_loans', 'credit_reporting', 'retail_banking'.\nComment: hi landed job reside ca needed room rent found place rent paid deposit dollar however position going didnt work longer needed rent place bay asked landlord refund security deposit refused told called back wellsfargo disputed transaction recently noticed card reversal checking account got charged amount dollar called bank werent able refund money also emailed landlord asking refund money ten day passed still response hope cfpb take action successfully resolve issue thank\nProduct:\n")
response
'credit_card'
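Prompts for the tuned deployment follow the shape of the sample above: the label list, the complaint narrative after "Comment:", and a trailing "Product:". A small helper (an illustrative assumption, not part of the ibm-watsonx-ai SDK) can build such prompts from raw narratives:

```python
# Illustrative helper (not part of the SDK) that builds a classification
# prompt in the same shape as the sample prompt above.
LABELS = ("'credit_card', 'debt_collection', 'mortgages_and_loans', "
          "'credit_reporting', 'retail_banking'")

def build_prompt(narrative):
    return (f"Including narratives choice the best match product with the "
            f"items from the list: {LABELS}.\nComment: {narrative}\nProduct:\n")

prompt = build_prompt("card was charged twice for the same purchase")
```

The resulting string can be passed directly to `tuned_model.generate_text(prompt=...)`.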
Download and prepare the cfpb_test dataset.
import pandas as pd
filename = 'cfpb_test.json'
url = "https://raw.github.com/IBM/watson-machine-learning-samples/master/cloud/data/prompt_tuning/cfpb_test.json"
if not os.path.isfile(filename):
wget.download(url)
data = pd.read_json(filename)
prompts = list(data.input)
products = list(data.output)
prompts_batch = list(prompts)  # each prompt is already a complete input string
tuned_model_results = tuned_model.generate_text(prompt=prompts_batch)
from sklearn.metrics import accuracy_score
print(f'accuracy_score: {accuracy_score(products, tuned_model_results)}')
accuracy_score: 0.6266666666666667
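Overall accuracy can hide classes the model handles poorly. As an optional extra step (an assumption, not part of the notebook's original evaluation), a per-class breakdown is easy to compute without any additional dependencies:

```python
from collections import Counter

# Illustrative per-class accuracy breakdown; overall accuracy can hide
# classes that the tuned model predicts poorly.
def per_class_accuracy(y_true, y_pred):
    correct, total = Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    return {label: correct[label] / total[label] for label in total}

print(per_class_accuracy(["a", "a", "b"], ["a", "b", "b"]))  # → {'a': 0.5, 'b': 1.0}
```

In the notebook you would call `per_class_accuracy(products, tuned_model_results)`.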
Initialize the ModelInference class with the base model.
base_model = ModelInference(
model_id='google/flan-t5-xl',
params=generate_params,
api_client=client
)
base_model_results = base_model.generate_text(prompt=prompts_batch)
from sklearn.metrics import accuracy_score
print(f'base model accuracy_score: {accuracy_score(products, base_model_results)}')
base model accuracy_score: 0.5333333333333333
You successfully completed this notebook!
You learned how to use prompt tuning in watsonx to analyze product classes for sample prompts.
Check out our Online Documentation for more samples, tutorials, documentation, how-tos, and blog posts.
Mateusz Szewczyk, Software Engineer at Watson Machine Learning.
Copyright © 2023 IBM. This notebook and its source code are released under the terms of the MIT License.