Choosing your framework, fusion method, and hyperparameters

When you create your Federated Learning experiment, you must choose a framework and fusion method for the Federated Learning model. These choices also determine which hyperparameters you can set.

Frameworks, model types, and fusion methods

The following shows the available frameworks, their model types and fusion methods, and the hyperparameters that each fusion method supports. For definitions of the hyperparameters, see the hyperparameters table later in this topic.

TensorFlow 2
A TensorFlow 2.x based platform used to build neural networks and more. Requires a pre-existing untrained model that is provided by your team's data scientist in SavedModel format. Browse and upload the model file and give it a name. See TensorFlow 2 model configuration for examples of TensorFlow 2 model files.

Model type: Any (N/A)
Fusion methods:
- Simple Avg: The simplest aggregation, used as a baseline, in which all parties' model updates are weighted equally.
  Hyperparameters: Rounds; Termination accuracy (Optional); Quorum (Optional); Max Timeout (Optional)
- Weighted Avg: Weights the average of the updates by the number of samples that each party holds. Use with training data sets of widely differing sizes.
  Hyperparameters: Rounds; Termination accuracy (Optional); Quorum (Optional); Max Timeout (Optional)
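As an illustration of the SavedModel requirement, a minimal sketch of defining an untrained Keras model and exporting it with TensorFlow 2.x (for example, the 2.7 version that the runtime uses) might look like this; the layer sizes and directory name are placeholder assumptions, not values from the product documentation:

```python
import tensorflow as tf

# Untrained model; the architecture here is a placeholder, not a recommendation.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])

# With TensorFlow 2.x, saving to a directory path writes the SavedModel format,
# which is what the experiment expects you to browse to and upload.
model.save("untrained_model")
```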
Scikit-learn
The Scikit-learn Python machine learning package that is used for predictive data analysis and more. Requires a pre-existing untrained model that is provided by your team's data scientist, with configuration to save the model as a pickle file. See Scikit-learn model configuration for examples of Scikit-learn model files. You must modify the model configuration to save the model, because Federated Learning does not save it by default.

Model types and fusion methods:
- Classification: The classification machine learning algorithm with Scikit-learn. Note: You must specify the class label for IBM Federated Learning classification models.
  - Simple Avg: The simplest aggregation, used as a baseline, in which all parties' model updates are weighted equally.
    Hyperparameters: Rounds; Termination accuracy (Optional)
  - Weighted Avg: Weights the average of the updates by the number of samples that each party holds. Use with training data sets of widely differing sizes.
    Hyperparameters: Rounds; Termination accuracy (Optional)
- Regression: The regression machine learning algorithm with Scikit-learn.
  - Simple Avg: The simplest aggregation, used as a baseline, in which all parties' model updates are weighted equally.
    Hyperparameters: Rounds
  - Weighted Avg: Weights the average of the updates by the number of samples that each party holds. Use with training data sets of widely differing sizes.
    Hyperparameters: Rounds
- XGBoost: Machine learning algorithms with XGBoost.
  - XGBoost Classification: Use to perform classification tasks that use XGBoost.
    Hyperparameters: Learning rate; Loss; Rounds; Number of classes
  - XGBoost Regression: Use to perform regression tasks that use XGBoost.
    Hyperparameters: Learning rate; Rounds; Loss
- K-Means/SPAHM: Used to train KMeans (unsupervised learning) models when parties have heterogeneous data sets.
  Hyperparameters: Max Iter; N cluster
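Because saving the untrained Scikit-learn model as a pickle file must be added to the model configuration, a minimal sketch might look like the following; the choice of SGDClassifier and the file name are illustrative assumptions, not values from the product documentation:

```python
import pickle

from sklearn.linear_model import SGDClassifier

# Untrained model supplied by the data scientist; SGDClassifier is a placeholder choice.
model = SGDClassifier()

# Federated Learning expects the Scikit-learn model serialized as a pickle file.
with open("untrained_model.pickle", "wb") as f:
    pickle.dump(model, f)
```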
PyTorch
PyTorch can be used for training neural network models in Python and more.

Model types and fusion methods:
- Any (N/A)
  - Simple Avg: The simplest aggregation, used as a baseline, in which all parties' model updates are weighted equally.
    Hyperparameters: Rounds; Epochs; Quorum (Optional); Max Timeout (Optional)
- Neural Networks: A method of using layers of neurons to train deep learning models.
  - Probabilistic Federated Neural Matching (PFNM): A communication-efficient method for fully connected neural networks when parties have heterogeneous data sets.
    Hyperparameters: Rounds; Termination accuracy (Optional); Epochs; sigma; sigma0; gamma; iters
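For PyTorch, an untrained model can likewise be defined and serialized up front; the following sketch uses torch.save on a placeholder network (the layer sizes and file name are illustrative assumptions, not values from the product documentation):

```python
import torch
from torch import nn

# Untrained network; the layer sizes are placeholders, not a recommendation.
model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 2),
)

# Serialize the untrained model so that it can be supplied to the experiment.
torch.save(model, "untrained_model.pt")
```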

Frameworks and Python version compatibility

Watson Studio Framework   Python Selection   Software Spec        Party Python Version   Python Client Extras   Party Framework
scikit-learn              3.9                runtime-22.1-py3.9   3.9                    fl-rt22.1              scikit-learn 1.0.2
TensorFlow                3.9                runtime-22.1-py3.9   3.9                    fl-rt22.1              tensorflow 2.7.1
PyTorch                   3.9                runtime-22.1-py3.9   3.9                    fl-rt22.1              torch 1.10.2

Hyperparameters table

Rounds: Int value. The number of training iterations to complete between the aggregator and the remote systems.

Termination accuracy (Optional): Float value. Takes the model accuracy (model_accuracy) and compares it to a numerical value. If the condition is satisfied, the experiment finishes early.
For example, termination_predicate: accuracy >= 0.8 finishes the experiment when the mean model accuracy of the participating parties is greater than or equal to 80%. Currently, Federated Learning accepts one type of early termination condition (model accuracy) for classification models only.

Quorum (Optional): Float value. Proceeds with model training after the aggregator reaches a certain ratio of party responses. Takes a decimal value in the range 0 - 1; the default is 1. Model training starts only after the party responses reach the indicated ratio value.
For example, setting this value to 0.5 starts the training after 50% of the registered parties have responded to the aggregator call.

Max Timeout (Optional): Int value. Terminates the Federated Learning experiment if the waiting time for party responses exceeds this value in seconds. Takes a numerical value up to 43200. If this many seconds pass and the quorum ratio is not reached, the experiment terminates.
For example, max_timeout = 1000 terminates the experiment after 1000 seconds if the parties do not respond in that time.

Number of classes: Int value. The number of target classes for the classification model. Required if the Loss hyperparameter is auto, binary_crossentropy, or categorical_crossentropy.

Learning rate: Decimal value. The learning rate, also known as shrinkage, used as a multiplicative factor for the leaf values.

Loss: String value. The loss function to use in the boosting process.
- binary_crossentropy (also known as logistic loss) is used for binary classification.
- categorical_crossentropy is used for multiclass classification.
- auto chooses either loss function depending on the nature of the problem.
- least_squares is used for regression.

Max Iter: Int value. The total number of passes over the local training data set to train a Scikit-learn model.

N cluster: Int value. The number of clusters to form and the number of centroids to generate.

sigma: Float value. Determines how far the local model neurons are allowed to deviate from the global model. A bigger value allows more matching and produces a smaller global model. The default value is 1.

sigma0: Float value. Defines the permitted deviation of the global network neurons. The default value is 1.

gamma: Float value. Indian Buffet Process parameter that controls the expected number of features in each observation. The default value is 1.
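As a point of reference, the N cluster and Max Iter hyperparameters correspond to the n_clusters and max_iter arguments of a plain (non-federated) Scikit-learn KMeans model; the following sketch uses made-up placeholder data and values for illustration only:

```python
import numpy as np
from sklearn.cluster import KMeans

# Placeholder data: 100 random two-dimensional points.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))

# "N cluster" maps to n_clusters, "Max Iter" maps to max_iter.
model = KMeans(n_clusters=3, max_iter=300, n_init=10, random_state=0)
model.fit(X)

# One centroid per cluster, in the same dimensionality as the data.
print(model.cluster_centers_.shape)  # (3, 2)
```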

Parent topic: Creating the Federated Learning Experiment
