Federated Learning, Part 2: Implementation with the Flower Framework 🌼


This is the second part of the federated learning series I am doing. If you just landed here, I recommend going through the first part, where we discussed how federated learning works at a high level. For a quick refresher, here is an interactive app I created in a marimo notebook, where you can perform local training, merge models using the Federated Averaging (FedAvg) algorithm and observe how the global model improves across federated rounds.

An interactive visualization of federated learning where you control the training process and watch the global model evolve. (Inspired by AI Explorables)

In this part, our focus will be on implementing the federated logic using the Flower framework.

What happens when models are trained on skewed datasets

In the first part, we discussed how federated learning was used for early COVID screening with Curial AI. If the model had been trained only on data from a single hospital, it would have learnt patterns specific to that hospital and would have generalised poorly to out-of-distribution data. So far this is theory, so let us now put a number to it.

I am borrowing an example from the Flower Labs course on DeepLearning.AI because it uses the familiar MNIST dataset, which keeps the idea easy to understand without getting lost in details. The example shows what happens when models are trained on biased local datasets, and we then use the same setup to show how federated learning changes the outcome.

  • I have made a few small modifications to the original code. In particular, I use the Flower Datasets library, which makes it easy to work with datasets for federated learning scenarios.
  • 💻 You can access the code here to follow along. 

Splitting the Dataset

We start by taking the MNIST dataset and splitting it into three parts that represent data held by different clients, say three different hospitals. We also remove certain digits from each split so that every client has incomplete data, as shown below. This simulates real-world data silos.

Simulating real-world data silos where each client sees only a partial view.

As shown in the image above, client 1 never sees digits 1, 3 and 7. Similarly, client 2 never sees 2, 5 and 8, and client 3 never sees 4, 6 and 9. Even though all three datasets come from the same source, they represent quite different distributions.
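The split-and-filter step can be sketched in plain Python. This is only an illustration: the article's own code uses the Flower Datasets library, while the function name `make_client_splits`, the round-robin three-way split and the random stand-in labels below are assumptions made so the snippet runs without downloading MNIST.

```python
import random


def make_client_splits(labels, excluded_per_client, seed=0):
    """Shuffle example indices, deal them into one shard per client,
    then drop every example whose digit that client must never see.
    (Hypothetical helper; the round-robin split is an assumption.)"""
    rng = random.Random(seed)
    indices = list(range(len(labels)))
    rng.shuffle(indices)
    num_clients = len(excluded_per_client)
    shards = [indices[c::num_clients] for c in range(num_clients)]
    return [
        [i for i in shard if labels[i] not in excluded]
        for shard, excluded in zip(shards, excluded_per_client)
    ]


# Digits each client never sees, as in the figure above.
EXCLUDED = [{1, 3, 7}, {2, 5, 8}, {4, 6, 9}]

# Random labels stand in for the MNIST training labels.
label_rng = random.Random(42)
labels = [label_rng.randrange(10) for _ in range(3000)]

splits = make_client_splits(labels, EXCLUDED)
for cid, split in enumerate(splits, start=1):
    seen = sorted({labels[i] for i in split})
    print(f"client {cid}: {len(split)} examples, digits {seen}")
```

Because the shards are disjoint before filtering, no example ends up on two clients, which mirrors data that physically lives in separate hospitals.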

Training on Biased Data

Next, we train separate models on each dataset using the same architecture and training setup. We use a very simple neural network implemented in PyTorch with just two fully connected layers and train the model for 10 epochs.
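For reference, here is what such a model and training loop might look like in PyTorch. This is a sketch, not the course's code: the hidden size of 128, plain SGD, the learning rate, the helper names `Net` and `train`, and the random tensors standing in for one client's filtered MNIST split are all assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class Net(nn.Module):
    """Two fully connected layers for 28x28 grayscale digits (hidden size assumed)."""

    def __init__(self, hidden: int = 128, num_classes: int = 10):
        super().__init__()
        self.flatten = nn.Flatten()                # (N, 1, 28, 28) -> (N, 784)
        self.fc1 = nn.Linear(28 * 28, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(self.flatten(x)))
        return self.fc2(x)                         # raw logits, one per digit


def train(model: nn.Module, loader: DataLoader, epochs: int = 10, lr: float = 0.01) -> list:
    """Plain SGD with cross-entropy; returns the mean training loss of each epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running = 0.0
        for images, targets in loader:
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            running += loss.item() * images.size(0)
        epoch_losses.append(running / len(loader.dataset))
    return epoch_losses


# Random tensors stand in for one client's filtered MNIST split.
torch.manual_seed(0)
images = torch.randn(256, 1, 28, 28)
targets = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(images, targets), batch_size=32, shuffle=True)
model = Net()
losses = train(model, loader, epochs=10)
print([round(value, 3) for value in losses])
```

Each of the three clients would train its own `Net` instance this way on its own split, which is what produces the three separate loss curves.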

Loss curves indicate successful training on local data, but testing will reveal the impact of missing classes.

As the loss curves above show, the loss gradually goes down during training, which indicates that the models are learning something. Remember, however, that each model learns only from its own limited view of the data; only testing on a held-out set will reveal its true accuracy.

Evaluating on Unseen Data

To test the models, we load the MNIST test dataset with the same normalization applied to the training data. When we evaluate these models on the complete test set (all 10 digits), accuracy lands around 65 to 70 percent, which seems reasonable given that three digits were missing from each training dataset. It is at least well above the 10 percent expected from random guessing.
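A quick calculation shows why roughly 70 percent is close to the best these models could do. If a model classified every digit it had seen perfectly and got every missing digit wrong, its accuracy would equal the share of the test set made up of seen digits. The sketch below uses the per-class counts of the standard 10,000-image MNIST test split; `accuracy_ceiling` is a hypothetical helper name.

```python
# Per-class counts of the standard 10,000-image MNIST test split.
MNIST_TEST_COUNTS = {0: 980, 1: 1135, 2: 1032, 3: 1010, 4: 982,
                     5: 892, 6: 958, 7: 1028, 8: 974, 9: 1009}


def accuracy_ceiling(missing_digits, counts=MNIST_TEST_COUNTS):
    """Best possible accuracy if every seen digit is right and every missing digit is wrong."""
    total = sum(counts.values())
    reachable = total - sum(counts[d] for d in missing_digits)
    return reachable / total


for missing in ([1, 3, 7], [2, 5, 8], [4, 6, 9]):
    print(missing, f"{accuracy_ceiling(missing):.1%}")
# [1, 3, 7] 68.3%
# [2, 5, 8] 71.0%
# [4, 6, 9] 70.5%
```

Observed accuracies of 65 to 70 percent therefore sit just below these ceilings: the models do well on what they saw and simply cannot score on what they did not.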


Next, we also evaluate how individual models perform on data examples that were not represented in their training set. For that, we create three specific test subsets:

  • Test set [1,3,7] only includes digits 1, 3, and 7
  • Test set [2,5,8] only includes digits 2, 5, and 8
  • Test set [4,6,9] only includes digits 4, 6, and 9
Models perform reasonably on all digits but completely fail on classes they never saw during training
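Building these subsets and scoring them only takes a label filter. In the sketch below, `subset_indices` and `accuracy_on` are hypothetical helper names, and the predictions are made up to mimic a model that never saw 1, 3 and 7: it is right on every other digit and maps the missing ones to other labels.

```python
def subset_indices(labels, digits):
    """Indices of test examples whose label is one of `digits`."""
    wanted = set(digits)
    return [i for i, y in enumerate(labels) if y in wanted]


def accuracy_on(preds, labels, indices):
    """Fraction of the selected examples that were classified correctly."""
    if not indices:
        raise ValueError("empty subset")
    return sum(preds[i] == labels[i] for i in indices) / len(indices)


# Toy stand-in for model 1's predictions on a tiny test set.
labels = list(range(10)) * 2
remap = {1: 2, 3: 8, 7: 2}          # missing digits get some other label
preds = [remap.get(y, y) for y in labels]

for digits in ([1, 3, 7], [2, 5, 8], [4, 6, 9]):
    idx = subset_indices(labels, digits)
    print(f"Test set {digits}: {accuracy_on(preds, labels, idx):.0%}")
print(f"All digits: {accuracy_on(preds, labels, list(range(len(labels)))):.0%}")
```

Even this toy version reproduces the pattern: 0 percent on the unseen subset, perfect on the others, and a respectable-looking 70 percent overall.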

When we evaluate each model only on the digits it never saw during training, accuracy drops to 0 percent. The models completely fail on classes they were never exposed to. This is also expected, since a model cannot learn to recognize patterns it has never seen. But there is more to it than meets the eye, so next we look at the confusion matrix to understand the behavior in more detail.

Understanding the Failure Through Confusion Matrices

Below is the confusion matrix for model 1 that was trained on data excluding digits 1, 3, and 7. Since these digits were never seen during training, the model almost never predicts those labels. 

Instead, it tends to predict visually similar digits. Because label 1 is missing, for example, the model never outputs 1 and instead predicts digits like 2 or 8. The same pattern appears for the other missing classes. In other words, the model does not signal that it is unsure; it assigns high confidence to the wrong label, which is clearly not the behavior we want.

The confusion matrix shows how missing training data leads to systematic misclassification: absent classes are never predicted, and similar-looking alternatives are assigned with high confidence
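Reading a confusion matrix this way is easy to automate. The sketch below, with hypothetical helper names and made-up predictions rather than the article's real model outputs, builds the matrix (rows are true digits, columns are predicted digits) and lists the classes whose column is all zeros, which is the signature of a class the model never predicts. Note that a confusion matrix counts predictions only; measuring confidence would additionally require the model's softmax probabilities.

```python
def confusion_matrix(true_labels, pred_labels, num_classes=10):
    """Rows are true digits, columns are predicted digits."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, pred_labels):
        matrix[t][p] += 1
    return matrix


def never_predicted(matrix):
    """Classes whose whole column is zero, i.e. the model never outputs them."""
    n = len(matrix)
    return [c for c in range(n) if all(matrix[r][c] == 0 for r in range(n))]


# Toy predictions from a hypothetical model trained without 1, 3 and 7.
true_labels = list(range(10)) * 3
pred_labels = [{1: 2, 3: 8, 7: 2}.get(y, y) for y in true_labels]

cm = confusion_matrix(true_labels, pred_labels)
print("never predicted:", never_predicted(cm))   # [1, 3, 7]
print("true 1 predicted as 2:", cm[1][2])        # 3
```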

This example shows the limits of training in isolation on skewed data. When each client has only a partial view of the true distribution, models fail in systematic ways that overall accuracy does not capture. This is exactly the problem federated learning is meant to address, and it is what we will implement in the next section using the Flower framework.

What is Flower 🌼?

Flower is an open source framework that makes federated learning easy to implement, even for beginners. It is framework-agnostic, so you can use it with PyTorch, TensorFlow, Hugging Face, JAX and more. The same core abstractions also apply whether you are running experiments on a single machine or training across real devices in production.

Flower models federated learning in a very direct way. A Flower app is built around the same roles we discussed in the previous article: clients, a server and a strategy that connects them. Let’s now look at these roles in more detail.

Understanding Flower Through Simulation

Flower makes it very easy to start with federated learning without worrying about any complex setup. For local simulation, there are basically two commands you need to care about: 

  • one to generate the app: flwr new
  • one to run it: flwr run

You define a Flower app once and then run it locally to simulate many clients. Even though everything runs on a single machine, Flower treats each client as an independent participant with its own data and training loop. This makes it much easier to experiment and test before moving to a real deployment.


Let us start by installing the latest version of Flower, which at the time of writing this article is 1.25.0.

# Install Flower in a virtual environment
pip install -U flwr

# Check the installed version
flwr --version
# Output: Flower version: 1.25.0

The fastest way to create a working Flower app is to let Flower scaffold one for you via flwr new.

flwr new  # select from a list of templates

or

flwr new @flwrlabs/quickstart-pytorch  # specify a template directly

You now have a complete project with a clean structure to start with.

quickstart-pytorch
├── pytorchexample
│   ├── client_app.py   
│   ├── server_app.py   
│   └── task.py         
├── pyproject.toml      
└── README.md

There are three main files in the project:

  • The task.py file defines the model, dataset and training logic. 
  • The client_app.py file defines what each client does locally. 
  • The server_app.py file coordinates training and aggregation, typically with Federated Averaging, although you can swap in a different strategy.

Running the federated simulation

We can now run the federation using the commands below.

pip install -e . 
flwr run .

This single command starts the server, creates simulated clients, assigns data partitions and runs federated training end to end. 


An important point to note here is that the server and clients do not call each other directly. All communication happens through message objects, and each message carries model parameters, metrics and configuration values: model weights travel as array records (ArrayRecord), metrics such as loss or accuracy as metric records (MetricRecord), and values like the learning rate as config records (ConfigRecord).

During each round, the server sends the current global model to selected clients, the clients train locally and return updated weights together with metrics, and the server aggregates the results. The server may also run an evaluation step in which clients only report metrics, without updating the model.
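The message flow can be mimicked in a few lines of plain Python. The sketch below is not Flower's API: `ToyMessage`, `client_train`, `client_evaluate` and `server_round` are hypothetical names, the "model" is a list of floats, and "training" just nudges each weight toward the mean of a client's data. It only shows what travels in each direction: training replies carry weights plus metrics, evaluation replies carry metrics only. Aggregating the returned weights, which Flower leaves to the strategy, is omitted here.

```python
from dataclasses import dataclass, field


@dataclass
class ToyMessage:
    """Hypothetical stand-in for a Flower Message (not Flower's API)."""
    arrays: dict = field(default_factory=dict)   # model weights, like an array record
    metrics: dict = field(default_factory=dict)  # loss, num-examples, like a metric record
    config: dict = field(default_factory=dict)   # e.g. learning rate, like a config record


def mse(weights, target):
    return sum((w - target) ** 2 for w in weights) / len(weights)


def client_train(msg: ToyMessage, local_data: list) -> ToyMessage:
    """Toy local training: move each weight toward the mean of the local data."""
    target = sum(local_data) / len(local_data)
    lr = msg.config["lr"]
    weights = [w + lr * (target - w) for w in msg.arrays["weights"]]
    # A training reply carries updated weights and metrics.
    return ToyMessage(arrays={"weights": weights},
                      metrics={"train_loss": mse(weights, target),
                               "num-examples": len(local_data)})


def client_evaluate(msg: ToyMessage, local_data: list) -> ToyMessage:
    """An evaluation reply carries metrics only; the model is not updated."""
    target = sum(local_data) / len(local_data)
    return ToyMessage(metrics={"eval_loss": mse(msg.arrays["weights"], target),
                               "num-examples": len(local_data)})


def server_round(global_weights, client_datasets, lr):
    """Send the global model out, then collect training and evaluation replies."""
    train_msg = ToyMessage(arrays={"weights": global_weights}, config={"lr": lr})
    train_replies = [client_train(train_msg, data) for data in client_datasets]
    eval_msg = ToyMessage(arrays={"weights": global_weights})
    eval_replies = [client_evaluate(eval_msg, data) for data in client_datasets]
    return train_replies, eval_replies


clients = [[1.0, 2.0, 3.0], [4.0, 5.0], [10.0]]
train_replies, eval_replies = server_round([0.0, 0.0], clients, lr=0.5)
for reply in train_replies:
    print(reply.arrays, reply.metrics)
```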

If you look inside the generated pyproject.toml, you will also see how the simulation is defined. 

[tool.flwr.app.components]
serverapp = "pytorchexample.server_app:app"
clientapp = "pytorchexample.client_app:app"

This section tells Flower which Python objects implement the ServerApp and ClientApp. These are the entry points Flower uses when it launches the federation.

[tool.flwr.app.config]
num-server-rounds = 3
fraction-evaluate = 0.5
local-epochs = 1
learning-rate = 0.1
batch-size = 32

These values define the run configuration. They control how many server rounds are executed, how long each client trains locally and which training parameters are used. At runtime, these settings are available through the Flower Context object, for example as context.run_config["local-epochs"].

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 10

This section defines the local simulation itself. Setting options.num-supernodes = 10 tells Flower to create ten simulated clients. Each SuperNode runs one ClientApp instance with its own data partition.

Here is a quick rundown of the steps mentioned above.


Now that we have seen how easy it is to run a federated simulation with Flower, we will apply this structure to our MNIST example and revisit the skewed data problem we observed earlier.

Improving Accuracy through Collaborative Training

Now let’s go back to our MNIST example. We saw that the models trained on individual local datasets didn’t give good results. In this section, we change the setup so that clients now collaborate by sharing model updates instead of working in isolation. Each dataset, however, is still missing certain digits like before and each client still trains locally.


The best part about the project we scaffolded in the previous section is that it can easily be adapted to our use case. I took the Flower app generated there and made a few changes to the client_app.py, server_app.py and task.py files. I configured training to run for three server rounds, with all clients participating in every round and each client training its local model for ten local epochs. All of these settings can be managed via the pyproject.toml file. The local models are then aggregated into a single global model using Federated Averaging.
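Federated Averaging itself is a weighted mean of the client models, computed layer by layer, where each client's weight is proportional to its number of training examples. Flower's built-in FedAvg strategy handles this for you; the sketch below, with a hypothetical `fedavg` helper and toy NumPy arrays standing in for model layers, only shows the arithmetic.

```python
import numpy as np


def fedavg(client_layers, num_examples):
    """Weighted average of client models.

    client_layers: one list of numpy arrays (one array per layer) per client.
    num_examples:  number of local training examples for each client.
    """
    total = sum(num_examples)
    coeffs = [n / total for n in num_examples]  # each client's share of the data
    num_layers = len(client_layers[0])
    return [
        sum(c * layers[i] for c, layers in zip(coeffs, client_layers))
        for i in range(num_layers)
    ]


# Three toy "models", each with a 2x2 weight matrix and a bias vector.
clients = [
    [np.full((2, 2), 1.0), np.full(2, 1.0)],
    [np.full((2, 2), 2.0), np.full(2, 2.0)],
    [np.full((2, 2), 4.0), np.full(2, 4.0)],
]
global_model = fedavg(clients, num_examples=[100, 100, 200])
print(global_model[0])
```

With equal dataset sizes this reduces to a plain mean. Here the third client holds half of the data, so every entry of the result (0.25 × 1 + 0.25 × 2 + 0.5 × 4 = 2.75) is pulled toward its weights.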

The global federated model achieves 95.6% overall accuracy and strong performance (93–97%) on all digit subsets, including those missing from individual clients.

Now let’s look at the results. Remember that with isolated training, the three individual models achieved an accuracy of roughly 65 to 70%. With federated learning, accuracy jumps to around 96%, so the global model is much better than any of the individual models trained in isolation.

The global model also performs far better on the specific subsets (the digits missing from each client’s data): accuracy on them jumps from 0% to between 94 and 97%.

Unlike the individual biased models, the federated global model successfully predicts all digit classes with high accuracy 

The confusion matrix above corroborates this finding. The model learns to classify all digits correctly, including those that were missing from individual clients’ data. No column contains only zeros anymore, and every digit class now receives predictions, showing that collaborative training let the model learn the complete data distribution even though no single client had access to all digit types.

Looking at the big picture 

While this is a toy example, it helps to provide the intuition behind why federated learning is so powerful. This same principle can be applied to situations where data is distributed across multiple locations and cannot be centralized due to privacy or regulatory constraints. 

Isolated training keeps data siloed with no collaboration (left) while federated learning enables hospitals to train a shared model without moving data (right).

For instance, if you replace the clients in the example above with three hospitals, each holding its own local data, the model trained through federated learning would be much better than any model trained in isolation on a single hospital’s limited dataset. At the same time, the data stays within each hospital, while the model benefits from the collective knowledge of all participating institutions.

Conclusion & What’s Next

That’s all for this part of the series. In this article, we implemented an end-to-end federated learning loop with Flower, walked through the components of a Flower app and compared isolated training with collaborative training. In the next part, we will look at federated learning from a privacy point of view. Although federated learning is itself a data minimization approach, since it avoids direct access to the data, the model updates exchanged between clients and server can still potentially leak private information. In the meantime, the official Flower documentation is a great place to dig deeper.
