This notebook demonstrates how to run Federated Learning training experiments with Homomorphic Encryption.
The learning goals of this notebook are:
IBM Federated Learning enables you to train a machine learning model across multiple decentralized parties that hold local data sets, without sharing those data sets. Such parties can be, for example, units within an enterprise, members of a consortium of enterprises, multiple data centers or clouds, or edge devices. This allows a collective machine learning model to be built without moving data between the nodes, thereby addressing data security, privacy, and regulatory compliance requirements, as well as eliminating data movement and its associated costs.
In the federated learning training process, the parties build locally trained machine learning models and send these local models to an aggregator. The aggregator fuses the local models into an aggregated model and sends this model back to the parties to continue with the next round of training.
For additional details see IBM Federated Learning documentation.
IBM Federated Learning uses SSL secured connections between the parties and the aggregator for communicating the machine learning models. In this setting, the aggregator can still see the unencrypted local and aggregated models.
IBM Federated Learning further includes homomorphic encryption capabilities, to enhance the parties’ data privacy and security in settings where the aggregator operates in an environment which is less trusted, and the parties wish to avoid revealing the local models and the aggregated models to the aggregator.
Homomorphic encryption (HE) is a form of encryption that enables performing computations on the encrypted data without decrypting it. The results of the computations remain in encrypted form which, when decrypted, results in an output that is the same as the output produced had the computations been performed on the unencrypted data.
In federated learning, homomorphic encryption enables the parties to homomorphically encrypt their local model updates before sending them to the aggregator. The aggregator sees only the homomorphically encrypted local model updates, and therefore cannot learn anything from this information. Specifically, the aggregator is not able to reverse-engineer the local model updates to discover information on local training data. The aggregator fuses the local model updates in their encrypted form, obtaining an encrypted aggregated model. Then the aggregator sends the encrypted aggregated model to the parties, which decrypt it and continue with the next round of training.
Homomorphic encryption is a form of public key cryptography. It uses a public key for encryption and a private key for decryption. In IBM Federated Learning with homomorphic encryption, the parties (also named “remote training systems”) share the private HE key, and the aggregator has only the public HE key. Each party encrypts its local model update using the public HE key, and sends its encrypted local model update to the aggregator. Since the aggregator does not have the private HE key, it cannot decrypt the encrypted local model updates.
The aggregator uses its public HE key to fuse the encrypted local model updates into a new encrypted aggregated model. This encrypted aggregated model is sent to the parties, which decrypt it using their private HE key, and continue the model training process.
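To build intuition for how an aggregator can fuse updates it cannot read, the following toy sketch uses a textbook additively homomorphic scheme (Paillier) with deliberately tiny, insecure parameters. It is purely illustrative and is not the scheme, key size, or code used by IBM Federated Learning.
# Toy illustration of additively homomorphic aggregation. This is NOT the
# scheme, key size, or code used by IBM Federated Learning; it is a minimal
# textbook Paillier example with deliberately tiny, insecure parameters.
import math
import random

P, Q = 293, 433                 # toy primes; real deployments use large keys
N = P * Q                       # public modulus (part of the public key)
N2 = N * N
LAM = math.lcm(P - 1, Q - 1)    # private key component
MU = pow(LAM, -1, N)            # since g = N + 1, L(g^LAM mod N^2) = LAM mod N

def encrypt(m):
    """Encrypt an integer 0 <= m < N using only the public key (N)."""
    while True:
        r = random.randrange(1, N)
        if math.gcd(r, N) == 1:
            break
    return (pow(N + 1, m, N2) * pow(r, N, N2)) % N2

def decrypt(c):
    """Decrypt a ciphertext using the private key (LAM, MU)."""
    u = pow(c, LAM, N2)
    return ((u - 1) // N) * MU % N

def add_encrypted(c1, c2):
    """Homomorphic addition: the product of ciphertexts decrypts to the sum."""
    return (c1 * c2) % N2

# Each "party" scales one model weight to an integer and encrypts it.
SCALE = 1000
local_updates = [0.12, 0.30, 0.45]
ciphertexts = [encrypt(int(w * SCALE)) for w in local_updates]

# The "aggregator" fuses the encrypted updates without ever decrypting them.
fused = ciphertexts[0]
for c in ciphertexts[1:]:
    fused = add_encrypted(fused, c)

# A party holding the private key decrypts the fused result and averages it.
print(decrypt(fused) / SCALE / len(local_updates))  # prints ~0.29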
IBM Federated Learning makes it easy to use homomorphic encryption in model training: you only need to specify simple parameters in the configurations of the aggregator and the parties. IBM Federated Learning includes a mechanism that automatically and securely generates and distributes homomorphic encryption keys among the parties participating in a training experiment.
2. Install the IBM Watson Machine Learning Python client package with homomorphic encryption support in the Python environment in which this notebook runs, using the following command:
pip install 'ibm_watsonx_ai[fl-rt23.1-py3.10,fl-crypto]'
You can use the installation cell in the next section of this notebook to perform this installation.
This installation is required for any Python environment that will be used for running parties in Federated Learning experiments with homomorphic encryption.
3. In your IBM Cloud account:
4. Create a Watson Machine Learning service instance. A free plan is offered. In your Watson Machine Learning service instance:
Install the IBM Watson Machine Learning Python client package with homomorphic encryption support, within the Python environment in which this notebook runs, if this package is not yet installed in this environment.
%pip install --upgrade 'ibm_watsonx_ai[fl-rt23.1-py3.10,fl-crypto]'
The following cell applies base definitions for the notebook.
User action: Before running the following cell, replace the mandatory TBDs in the cell with your information and review the optional TBDs.
import os
import subprocess
import urllib3
import requests
urllib3.disable_warnings()
cmd = subprocess.Popen("pip list | grep -E 'ibm-watsonx-ai|ibm_watsonx_ai'",
shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
wml_installed = len(cmd.communicate()[0]) > 0
if not wml_installed:
raise Exception('ibm-watsonx-ai package must be installed in the environment')
base_dir = os.getcwd() # TBD [optional] A base directory under which the notebook work directory will be created. Default is the current work directory.
nb_dir = os.path.join(base_dir, 'fl_fhe_nb')
data_path = os.path.join(nb_dir, 'data')
model_path = os.path.join(nb_dir, 'model')
crypto_path = os.path.join(nb_dir, 'crypto')
exec_path = os.path.join(nb_dir, 'exec')
if not os.path.exists(data_path):
os.makedirs(data_path)
if not os.path.exists(model_path):
os.makedirs(model_path)
if not os.path.exists(crypto_path):
os.makedirs(crypto_path)
if not os.path.exists(exec_path):
os.makedirs(exec_path)
os.chdir(exec_path)
PROJECT_ID = '' # TBD [mandatory] See the prerequisites section for details.
CLOUD_USERID = '' # TBD [mandatory] See the prerequisites section for details.
IAM_APIKEY = '' # TBD [mandatory] See the prerequisites section for details.
WML_SERVICES_LOCATION = '' # TBD [mandatory] See the prerequisites section for details.
WML_SERVICES_URL = 'https://' + WML_SERVICES_LOCATION + '.ml.cloud.ibm.com'
NUM_RTS = 3 # TBD [optional] The number of parties for the training experiment.
SW_SPEC_NAME = 'runtime-23.1-py3.10'
HW_SPEC_NAME = 'S'
RSC_TAGS = ['wml_fl_fhe_nb_example']
TIMEOUT_TRAINING_SEC = 600
crypto_file_ext = 'v1'
asym_file_is = crypto_path + "/is_asym_" + crypto_file_ext + ".pem"
cert_file_is = crypto_path + "/is_cert_" + crypto_file_ext + ".pem"
asym_file_sb = crypto_path + "/sb_asym_" + crypto_file_ext + "_"
csr_file_sb = crypto_path + "/sb_csr_" + crypto_file_ext + "_"
cert_file_sb = crypto_path + "/sb_cert_" + crypto_file_ext + "_"
prt_data_file_prefix = 'data_party_'
NUM_MODELS = 1
MODEL_NAME = 'pytorch'
MODEL_TYPE = 'pytorch-onnx_2.0'
INIT_MODEL_FILE_NAME = 'pt_mnist_init_model.zip'
INIT_MODEL_URL = 'https://github.com/IBMDataScience/sample-notebooks/raw/master/Files/pt_mnist_init_model.zip'
DATA_HANDLER_FILE_NAME = 'mnist_pytorch_data_handler.py'
DATA_HANDLER_CLASS_NAME = 'MnistPytorchDataHandler'
DATASET_FILE_NAME = 'mnist.npz'
DATASET_URL = 'https://api.dataplatform.cloud.ibm.com/v2/gallery-assets/entries/85ae67d0cf85df6cf114d0664194dc3b/data'
heartbeat_resp = requests.get(WML_SERVICES_URL + "/wml_services/training/heartbeat", verify=False)
print("Heartbeat response %s" % heartbeat_resp.content.decode("utf-8"))
This section creates and activates a WML client, which is used to interact with your WML instance.
from ibm_watsonx_ai import APIClient
wml_credentials = {
"url": WML_SERVICES_URL,
"apikey": IAM_APIKEY
}
wml_client = APIClient(wml_credentials)
wml_client.set.default_project(PROJECT_ID)
The WML assets created in this notebook are initial models and remote training systems.
In this section you can either create new assets, or reuse assets that were created in a previous session of this notebook and not removed.
Initial untrained model assets are required for Federated Learning.
In this notebook, an untrained Pytorch model is used.
For additional details see the documentation on creating initial models.
First, we download a pre-built initial model.
import shutil
print("Downloading initial model")
init_model_file_path = os.path.join(model_path, INIT_MODEL_FILE_NAME)
with requests.get(INIT_MODEL_URL, stream=True) as r:
with open(init_model_file_path, 'wb') as f:
shutil.copyfileobj(r.raw, f)
print('Model stored in: ' + str(init_model_file_path))
print("Done")
Next, we upload the initial model as an asset into the cluster.
print("Storing initial model")
sw_spec_id = wml_client.software_specifications.get_id_by_name(SW_SPEC_NAME)
untrained_model_ids = {}
model_metadata = {
wml_client.repository.ModelMetaNames.NAME: MODEL_NAME,
wml_client.repository.ModelMetaNames.TYPE: MODEL_TYPE,
wml_client.repository.ModelMetaNames.SOFTWARE_SPEC_UID: sw_spec_id,
wml_client.repository.ModelMetaNames.TAGS: RSC_TAGS
}
untrained_model_details = wml_client.repository.store_model(os.path.join(model_path, INIT_MODEL_FILE_NAME), model_metadata)
untrained_model_ids[MODEL_NAME] = wml_client.repository.get_model_id(untrained_model_details)
print('Model id: ' + str(untrained_model_ids[MODEL_NAME]))
print('Done')
A Remote Training System (RTS) asset defines a party that connects to the aggregator for a training experiment.
For additional details see the corresponding documentation page.
print("Creating Remote Training Systems")
remote_training_systems = []
for i in range(NUM_RTS):
rts_metadata = {
wml_client.remote_training_systems.ConfigurationMetaNames.NAME: "Party_"+str(i),
wml_client.remote_training_systems.ConfigurationMetaNames.TAGS: RSC_TAGS,
wml_client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name" : "IBM", "region": "US"},
wml_client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": CLOUD_USERID, "type": "user"}],
wml_client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": CLOUD_USERID, "type":"user"}
}
rts = wml_client.remote_training_systems.store(rts_metadata)
rts_id = wml_client.remote_training_systems.get_id(rts)
print('Remote training system Party_' + str(i) + ' id: ' + str(rts_id))
remote_training_systems.append({'id': rts_id, 'required': True})
print('Done')
Run the following cell if you are reusing WML assets that were created in previous sessions.
This code builds the notebook's internal lists from existing assets. These lists are used in later operations of this notebook.
import json
FORCE_REBUILD_DS = False
print("Models:")
if 'untrained_model_ids' not in globals() or FORCE_REBUILD_DS or \
len(untrained_model_ids) != NUM_MODELS:
untrained_model_ids = {}
load_models_dict = True
else:
load_models_dict = False
models = wml_client.repository.get_model_details(get_all=True)
for m in models['resources']:
md = m['metadata']
if not 'tags' in md or md['tags'] != RSC_TAGS:
continue
if load_models_dict:
untrained_model_ids[md['name']] = md['id']
print('{}: {}'.format(md['name'],md['id']))
print("Remote Training Systems:")
if 'remote_training_systems' not in globals() or FORCE_REBUILD_DS or \
len(remote_training_systems) != NUM_RTS:
remote_training_systems = []
load_rts_lst = True
else:
load_rts_lst = False
rts = wml_client.remote_training_systems.get_details()
for r in rts['resources']:
md = r['metadata']
if not 'tags' in md or md['tags'] != RSC_TAGS:
continue
if load_rts_lst:
remote_training_systems.append({'id': md['id'], 'required': True})
print('{}: {}'.format(md['name'],md['id']))
This section downloads the MNIST data set and splits it into subsets for the parties.
Then, it defines and stores a data handler.
import os
import requests
import numpy as np
import shutil
def load_mnist(normalize=True, download_dir=''):
"""
Download MNIST training data from the source used by `keras.datasets.mnist.load_data`
:param normalize: whether or not to normalize data
:type normalize: bool
:param download_dir: directory to download data
:type download_dir: `str`
:return: 2 tuples containing training and testing data respectively
:rtype (`np.ndarray`, `np.ndarray`), (`np.ndarray`, `np.ndarray`)
"""
local_file = os.path.join(download_dir, DATASET_FILE_NAME)
if not os.path.isfile(local_file):
with requests.get(DATASET_URL, stream=True) as r:
with open(local_file, 'wb') as f:
shutil.copyfileobj(r.raw, f)
with np.load(local_file, allow_pickle=True) as mnist:
x_train, y_train = mnist['x_train'], mnist['y_train']
x_test, y_test = mnist['x_test'], mnist['y_test']
if normalize:
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
np.savez(local_file, x_train=x_train, y_train=y_train,
x_test=x_test, y_test=y_test)
else:
with np.load(local_file, allow_pickle=True) as mnist:
x_train, y_train = mnist['x_train'], mnist['y_train']
x_test, y_test = mnist['x_test'], mnist['y_test']
return (x_train, y_train), (x_test, y_test)
def save_mnist_party_data(nb_dp_per_party, should_stratify, party_folder, dataset_folder):
"""
Saves MNIST party data
:param nb_dp_per_party: the number of data points each party should have
:type nb_dp_per_party: `list[int]`
:param should_stratify: True if data should be assigned proportional to source class distributions
:type should_stratify: `bool`
:param party_folder: folder to save party data
:type party_folder: `str`
:param dataset_folder: folder to save dataset
:type dataset_folder: `str`
"""
if not os.path.exists(dataset_folder):
os.makedirs(dataset_folder)
(x_train, y_train), (x_test, y_test) = load_mnist(download_dir=dataset_folder)
labels, train_counts = np.unique(y_train, return_counts=True)
te_labels, test_counts = np.unique(y_test, return_counts=True)
diff_labels = np.all(np.isin(labels, te_labels))
num_train = np.shape(y_train)[0]
num_test = np.shape(y_test)[0]
num_labels = np.shape(np.unique(y_test))[0]
nb_parties = len(nb_dp_per_party)
if should_stratify:
train_probs = {
label: train_counts[label] / float(num_train) for label in labels}
test_probs = {label: test_counts[label] /
float(num_test) for label in te_labels}
else:
train_probs = {label: 1.0 / len(labels) for label in labels}
test_probs = {label: 1.0 / len(te_labels) for label in te_labels}
for idx, dp in enumerate(nb_dp_per_party):
train_p = np.array([train_probs[y_train[idx]]
for idx in range(num_train)])
train_p /= np.sum(train_p)
train_indices = np.random.choice(num_train, dp, p=train_p)
test_p = np.array([test_probs[y_test[idx]] for idx in range(num_test)])
test_p /= np.sum(test_p)
test_indices = np.random.choice(
num_test, int(num_test / nb_parties), p=test_p)
x_train_pi = x_train[train_indices]
y_train_pi = y_train[train_indices]
x_test_pi = x_test[test_indices]
y_test_pi = y_test[test_indices]
name_file = prt_data_file_prefix + str(idx) + '.npz'
name_file = os.path.join(party_folder, name_file)
np.savez(name_file, x_train=x_train_pi, y_train=y_train_pi,
x_test=x_test_pi, y_test=y_test_pi)
print('Data saved in ' + party_folder)
return
save_mnist_party_data(nb_dp_per_party=[200 for _ in range(NUM_RTS)], should_stratify=False,
party_folder=data_path, dataset_folder=data_path)
print('Done')
This section creates a data handler Python file for the MNIST dataset to train using PyTorch.
For additional details see the corresponding documentation page.
%%writefile mnist_pytorch_data_handler.py
import numpy as np
from ibmfl.data.data_handler import DataHandler
class MnistPytorchDataHandler(DataHandler):
"""
Data handler for the MNIST dataset to train using PyTorch.
"""
def __init__(self, data_config=None):
super().__init__()
self.file_name = None
if data_config is not None:
if 'npz_file' in data_config:
self.file_name = data_config['npz_file']
# Load the datasets.
(self.x_train, self.y_train), (self.x_test, self.y_test) = self.load_dataset()
# Pre-process the datasets.
self.preprocess()
def get_data(self):
"""
Gets the pre-processed MNIST training and testing data.
:return: training and testing data
:rtype: `tuple`
"""
return (self.x_train, self.y_train), (self.x_test, self.y_test)
def load_dataset(self, nb_points=500):
"""
Loads the training and testing datasets from the local npz file specified
in the data configuration (self.file_name). Raises an IOError if the file
cannot be loaded.
:param nb_points: number of data points per set; unused here and kept for
interface compatibility.
:type nb_points: `int`
:return: training and testing datasets
:rtype: `tuple`
"""
try:
data_train = np.load(self.file_name)
x_train = data_train['x_train']
y_train = data_train['y_train']
x_test = data_train['x_test']
y_test = data_train['y_test']
except Exception:
raise IOError('Unable to load training data from path '
'provided in config file: ' +
self.file_name)
return (x_train, y_train), (x_test, y_test)
def preprocess(self):
"""
Preprocesses the training and testing dataset, \
e.g., reshape the images according to self.channels_first; \
convert the labels to binary class matrices.
:return: None
"""
img_rows, img_cols = 28, 28
self.x_train = self.x_train.astype('float32').reshape(self.x_train.shape[0], 1, img_rows, img_cols)
self.x_test = self.x_test.astype('float32').reshape(self.x_test.shape[0], 1,img_rows, img_cols)
self.y_train = self.y_train.astype('int64')
self.y_test = self.y_test.astype('int64')
import shutil
shutil.move(os.path.join('.', DATA_HANDLER_FILE_NAME), os.path.join(data_path, DATA_HANDLER_FILE_NAME))
This section creates the certificate and key files required for running a Federated Learning training experiment with encryption.
This section provides two methods for creating the cryptographic files: using the Python cryptography package, or using openssl. Use either one of these methods.
Homomorphic encryption keys are generated and distributed automatically and securely among the parties for each experiment. Only the parties participating in an experiment have access to the homomorphic encryption private key generated for the experiment.
To facilitate this generation and distribution process, the following steps must be performed before an experiment:
An RSA key pair and certificate for a party must be generated using the following parameters and guidelines:
Each party must be configured with paths to the following files:
Further details on this configuration are provided in the notebook section Launch parties.
In this notebook, we generate and provision self-signed certificates.
import os
import datetime
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.x509.oid import NameOID
class CryptoRsa():
KEY_SIZE = 4096
PUBLIC_EXPONENT = 65537
CRYPTO_HASH = hashes.SHA256()
def __init__(self):
self.private_key = CryptoRsa.generate_key()
def generate_key():
private_key = rsa.generate_private_key(
public_exponent=CryptoRsa.PUBLIC_EXPONENT,
key_size=CryptoRsa.KEY_SIZE,
)
return private_key
def get_public_key(self, type: str = "obj"):
if self.private_key is None:
raise Exception("self.private_key is None")
if type == "obj":
ret = self.private_key.public_key()
elif type == "pem":
ret = self.private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
else:
raise Exception("Invalid type=" + repr(type))
return ret
def write_key_file(self, file_path: str):
if self.private_key is None:
raise Exception("self.private_key is None")
pem = self.private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
with open(file_path, "wb") as key_file:
key_file.write(pem)
return
if not os.path.exists(crypto_path):
os.makedirs(crypto_path)
issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, u"US"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"California"),
x509.NameAttribute(NameOID.LOCALITY_NAME, u"San Francisco"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Issuer Company"),
x509.NameAttribute(NameOID.COMMON_NAME, u"mysite.com"),
])
subject = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, u"US"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"California"),
x509.NameAttribute(NameOID.LOCALITY_NAME, u"San Francisco"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Subject Company"),
x509.NameAttribute(NameOID.COMMON_NAME, u"mysite.com"),
])
issuer_key = CryptoRsa()
issuer_key.write_key_file(asym_file_is)
cert_is = x509.CertificateBuilder().subject_name(
issuer
).issuer_name(
issuer
).public_key(
issuer_key.get_public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=1000)
).add_extension(
x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
critical=False,
).sign(issuer_key.private_key, CryptoRsa.CRYPTO_HASH)
with open(cert_file_is, "wb") as f:
f.write(cert_is.public_bytes(serialization.Encoding.PEM))
for idx in range(NUM_RTS):
asym_file_path = asym_file_sb+str(idx)+".pem"
cert_file_path = cert_file_sb+str(idx)+".pem"
subject_key = CryptoRsa()
subject_key.write_key_file(asym_file_path)
cert_sb = x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
subject_key.get_public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=1000)
).add_extension(
x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
critical=False,
).sign(issuer_key.private_key, CryptoRsa.CRYPTO_HASH)
with open(cert_file_path, "wb") as f:
f.write(cert_sb.public_bytes(serialization.Encoding.PEM))
print('Done')
import os
if not os.path.exists(cert_file_is):
ret = os.system("openssl req -x509 -newkey rsa:4096 -sha256 -days 365 -nodes "
"-subj \"/C=US/ST=California/L=San Francisco/O=Issuer Company/OU=Org/CN=www.iscompany.com\" -keyout " +
str(asym_file_is) + " -out " + str(cert_file_is))
if ret != 0:
raise Exception("openssl for issuer failed: {}".format(ret))
for idx in range(NUM_RTS):
asym_file_path = asym_file_sb+str(idx)+".pem"
csr_file_path = csr_file_sb+str(idx)+".pem"
cert_file_path = cert_file_sb+str(idx)+".pem"
if not os.path.exists(cert_file_path):
ret = os.system("openssl req -newkey rsa:4096 -nodes -subj "
"\"/C=US/ST=California/L=San Francisco/O=SB Company/OU=Org/CN=www.sbcompany.com\" -keyout " +
str(asym_file_path) + " -out " + str(csr_file_path))
if ret != 0:
raise Exception("openssl for subject step 1 failed: {}".format(ret))
ret = os.system("openssl x509 -req -CAcreateserial -CA " + str(cert_file_is) + " -CAkey " + str(asym_file_is) +
" -sha256 -days 365 -in " + str(csr_file_path) + " -out " + str(cert_file_path))
if ret != 0:
raise Exception("openssl for subject step 2 failed: {}".format(ret))
print('Done')
This section launches the Federated Learning aggregator for the experiment.
To run the experiment with homomorphic encryption, the aggregator’s configuration must specify the following fusion type:
"fusion_type": "crypto_iter_avg"
The aggregator’s configuration may also include a crypto object that specifies the required encryption level. For example:
"crypto": {
"cipher_spec": "encryption_level_1"
}
If this object is not specified, the default of encryption_level_1 is used.
There are four possible encryption levels, ranging from level 1 to level 4. Higher encryption levels increase security and precision, but require more resources (e.g., computation, memory, network bandwidth). The security level corresponds to the strength of the encryption system, typically measured by the number of operations an attacker must perform to break it. The precision level corresponds to the precision of the encryption system's outcomes: at higher precision levels, cryptographic operations are accurate to a larger number of digits before and after the decimal point, which reduces the loss of model accuracy caused by the encryption operations.
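For example, assuming the level names follow the same encryption_level_&lt;n&gt; pattern shown above (an assumption based on the naming used in this notebook), a stronger level could be requested as follows, at the cost of additional computation, memory, and network bandwidth:
"crypto": {
    "cipher_spec": "encryption_level_3"
}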
Following is a description of the encryption levels:
For additional details on launching the aggregator see the corresponding documentation page.
fl_conf = {
"model": {
"type": MODEL_NAME,
"spec": {
"id": untrained_model_ids[MODEL_NAME]
},
"model_file": "pytorch_sequence.pt"
},
"fusion_type": "crypto_iter_avg",
"crypto": {
"cipher_spec": "encryption_level_1"
},
"epochs": 1,
"rounds": 2,
"metrics": "accuracy",
"remote_training": {
"max_timeout": TIMEOUT_TRAINING_SEC,
"quorum": 1,
"remote_training_systems": remote_training_systems,
},
"software_spec": {
"name": SW_SPEC_NAME
},
"hardware_spec": {
"name": HW_SPEC_NAME
}
}
aggregator_metadata = {
wml_client.training.ConfigurationMetaNames.NAME: 'aggregator_he',
wml_client.training.ConfigurationMetaNames.DESCRIPTION: '',
wml_client.training.ConfigurationMetaNames.TAGS: RSC_TAGS,
wml_client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
wml_client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
"type": "container",
"name": "outputData",
"connection": {},
"location": {
"path": "."
}
},
wml_client.training.ConfigurationMetaNames.FEDERATED_LEARNING: fl_conf
}
print("Prepared config for aggregator with model type {}".format(MODEL_NAME))
aggregator = wml_client.training.run(aggregator_metadata, asynchronous=True)
print("Created Aggregator")
training_id = wml_client.training.get_id(aggregator)
print("Training id: " + str(training_id))
print ("RTS: " + str(remote_training_systems))
This section launches the Federated Learning parties for the experiment.
To run the experiment with homomorphic encryption, each party’s configuration must include a crypto object inside the local_training object, which specifies the certificate and key files required for the party.
For additional details on launching the parties see the corresponding documentation page.
import os
for idx, prt in enumerate(remote_training_systems):
party_metadata = {
wml_client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
"info": {
"crypto": {
"key_manager": {
"key_mgr_info": {
"distribution": {
"ca_cert_file_path": cert_file_is,
"my_cert_file_path": cert_file_sb+str(idx)+'.pem',
"asym_key_file_path": asym_file_sb+str(idx)+'.pem'
}
}
}
}
}
},
wml_client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
"info": {
"npz_file": os.path.join(data_path, prt_data_file_prefix+str(idx)+'.npz')
},
"name": DATA_HANDLER_CLASS_NAME,
"path": os.path.join(data_path, DATA_HANDLER_FILE_NAME)
}
}
print("Connecting party id {} to aggregator id {}, model type {}".format(prt['id'], training_id, MODEL_NAME))
party = wml_client.remote_training_systems.create_party(prt['id'], party_metadata)
party.monitor_logs("ERROR")
party.run(aggregator_id=training_id, asynchronous=True, verify=False)
print("Party {} is running".format(prt['id']))
print('Done')
This section monitors the execution status of the training experiment.
For additional details on monitoring the experiment see the corresponding documentation page.
import time
import json
def monitor_training(training_id):
print('Monitoring training id: {}'.format(training_id))
MAX_ITER = 240
SLP_TIME_SEC = 10
aggregator_status = wml_client.training.get_status(training_id)
aggregator_state = aggregator_status['state']
iter = 0
while iter < MAX_ITER and 'completed' != aggregator_state and 'failed' != aggregator_state and 'canceled' != aggregator_state:
print("Elapsed time: {} seconds, State: {}".format(iter*SLP_TIME_SEC, aggregator_state))
time.sleep(SLP_TIME_SEC)
aggregator_status = wml_client.training.get_status(training_id)
aggregator_state = aggregator_status['state']
iter += 1
if iter >= MAX_ITER:
raise Exception("Training did not finish after {} seconds".format(iter*SLP_TIME_SEC))
print("Final status: " + json.dumps(aggregator_status, indent=4))
if 'training_id' in globals():
monitor_training(training_id)
else:
trn = wml_client.training.get_details(get_all=True)
for t in trn['resources']:
md = t['metadata']
if 'tags' in md and md['tags'] == RSC_TAGS:
monitor_training(md['id'])
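Once a run reaches the completed state, you can retrieve its full details for inspection. The sketch below reuses the get_details call already used in this notebook; the exact location of metrics inside the returned dictionary is an assumption and may vary by release.
# Sketch: inspect a completed run. The structure of the returned dictionary
# (e.g. where per-round metrics appear) may vary by release.
if 'training_id' in globals():
    details = wml_client.training.get_details(training_id)
    status = details.get('entity', {}).get('status', {})
    print("State: " + str(status.get('state')))
    print("Metrics: " + json.dumps(status.get('metrics', []), indent=4))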
Use this section to delete the training jobs, assets, and local files created using this notebook.
print('Removing training jobs')
trn = wml_client.training.get_details(get_all=True)
for t in trn['resources']:
md = t['metadata']
if 'tags' in md and md['tags'] == RSC_TAGS:
wml_client.training.cancel(md['id'], hard_delete=True)
print('Deleted {}: {}'.format(md['name'],md['id']))
print('Done')
print('Removing remote training systems')
rts = wml_client.remote_training_systems.get_details(get_all=True)
for r in rts['resources']:
md = r['metadata']
if 'tags' in md and md['tags'] == RSC_TAGS:
wml_client.repository.delete(md['id'])
print('Deleted {}: {}'.format(md['name'],md['id']))
print('Removing models')
models = wml_client.repository.get_model_details(get_all=True)
for m in models['resources']:
md = m['metadata']
if 'tags' in md and md['tags'] == RSC_TAGS:
wml_client.repository.delete(md['id'])
print('Deleted {}: {}'.format(md['name'],md['id']))
print('Done')
import shutil
shutil.rmtree(nb_dir)
You successfully completed this notebook!
Check out our online documentation and the IBM Federated Learning documentation for more tutorials and samples.
Copyright © IBM Corp. 2022-2024. This notebook and its source code are released under the terms of the MIT License.