
Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 8)

Welcome to the final part of our series! In this installment, we take a closer look at the search system that accompanies our proposal, then evaluate the overall effectiveness and performance of our solution and analyze how the search functionality contributes to it. We close by reflecting on the insights gained throughout the series.

1. Searching and Indexing Method in the CFOR System

To adapt our retrieval system to large-scale datasets, we built indexes for the CFOR system that support nonexhaustive similarity search with GPU acceleration, implementing the search algorithm described by Johnson et al. ("Billion-scale similarity search with GPUs") for retrieval. During search, the CFOR system improves accuracy by narrowing the search space with additional information such as regions, categories, and attributes. For indexing, the object ontology is used to create multiple index files that reduce search time. We focus on similarity search in vector collections using the L2 distance with k-selection (keeping only the k nearest vectors).
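A minimal sketch of the coarse-to-fine idea, assuming each database entry carries ontology metadata (region and category) alongside its deep feature vector: candidates are first restricted by metadata, then ranked by exact L2 distance with k-selection. It is plain Python for readability and does not reproduce the GPU algorithm of Johnson et al.; the `Item` fields and function names are hypothetical.

```python
import heapq
from dataclasses import dataclass


@dataclass
class Item:
    item_id: str
    region: str    # coarse ontology level, e.g. "Top", "Bottom", "Body"
    category: str  # e.g. "tee", "skirt", "dress"
    vector: list   # deep feature vector


def squared_l2(a, b):
    """Squared Euclidean distance; it ranks neighbors exactly as L2 does."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def coarse_to_fine_search(query_vec, region, category, database, k=5):
    """Filter by ontology metadata, then return the k nearest items by L2."""
    candidates = [item for item in database
                  if item.region == region and item.category == category]
    scored = ((squared_l2(query_vec, item.vector), item.item_id)
              for item in candidates)
    # k-selection: keep only the k smallest distances instead of sorting all.
    return [(item_id, dist) for dist, item_id in heapq.nsmallest(k, scored)]


db = [
    Item("a", "Body", "dress", [0.0, 0.0]),
    Item("b", "Body", "dress", [1.0, 1.0]),
    Item("c", "Top", "tee", [0.0, 0.1]),
    Item("d", "Body", "dress", [0.2, 0.0]),
]
print(coarse_to_fine_search([0.0, 0.0], "Body", "dress", db, k=2))
```

In this toy database, item `c` is very close to the query but never reaches the distance computation, because its region and category do not match; that pruning is what reduces both search time and false matches.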

For search, we distinguish between exact search (exhaustive search) and compressed-domain search (greedy nonexhaustive search).

  • Exact search: Most algorithms of this type compute the full pairwise distance between the query and each data point in the database, either sequentially or using an index file.
  • Compressed-domain search: Most algorithms of this type compute the distance between the query and each data point after applying a space transformation, encoding, subspace splitting, or hashing. Combined with index files, these methods reduce search time, at the cost of some search accuracy.
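To make subspace splitting and encoding concrete, here is a minimal product-quantization sketch, one common compressed-domain technique, assumed here for illustration rather than taken from the CFOR implementation. Each vector is split into `m` subvectors, each subvector is replaced by the index of its nearest centroid in a small per-subspace codebook, and query distances are read from per-query lookup tables (asymmetric distance computation). All function names are hypothetical, and the tiny k-means is only for demonstration.

```python
import heapq
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def split(vec, m):
    """Split a vector into m equal-length subvectors (assumes len(vec) % m == 0)."""
    d = len(vec) // m
    return [vec[i * d:(i + 1) * d] for i in range(m)]


def kmeans(points, k, iters=10, seed=0):
    """Tiny Lloyd's k-means, for illustration only."""
    centroids = [list(p) for p in random.Random(seed).sample(points, k)]
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda c: sq_dist(p, centroids[c]))
            clusters[nearest].append(p)
        for c, members in enumerate(clusters):
            if members:  # keep the old centroid if a cluster becomes empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def train_pq(data, m, k):
    """Learn one codebook of k centroids per subspace."""
    subspaces = list(zip(*(split(v, m) for v in data)))
    return [kmeans(list(sub), k, seed=i) for i, sub in enumerate(subspaces)]


def encode(vec, codebooks):
    """Replace each subvector by the index of its nearest centroid."""
    return [min(range(len(cb)), key=lambda c: sq_dist(sub, cb[c]))
            for sub, cb in zip(split(vec, len(codebooks)), codebooks)]


def decode(code, codebooks):
    """Approximate reconstruction: concatenate the chosen centroids."""
    return [x for c, cb in zip(code, codebooks) for x in cb[c]]


def adc_search(query, codes, codebooks, k=5):
    """Asymmetric distance computation: the query itself stays uncompressed."""
    tables = [[sq_dist(sub, centroid) for centroid in cb]
              for sub, cb in zip(split(query, len(codebooks)), codebooks)]
    scored = ((sum(tables[i][c] for i, c in enumerate(code)), idx)
              for idx, code in enumerate(codes))
    return heapq.nsmallest(k, scored)  # (approximate distance, index) pairs
```

Storing `m` small integers per image instead of the full feature vector is where the memory and time savings come from; the price is that the returned distances are those to the reconstructed vectors, not to the originals.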

2. Data

Our fashion retrieval system was built on a subset of approximately 300,000 images from DeepFashion. In DeepFashion, objects are captured from different viewpoints against complicated backgrounds, and each image carries rich, fine-grained annotations relevant to the model being trained. Figures 1 and 2 show samples from the dataset.

Figure 1. Images from the DeepFashion dataset captured from different viewpoints against complicated backgrounds.

Figure 2. Images from the DeepFashion dataset annotated with fine-grained labels relevant to the current model.

For testing, we use part of the benchmark data to fine-tune the trained models, and we ensure that no fashion item appears in both the fine-tuning and testing sets. The data are split into approximately 220,000 training images, 40,000 validation images, and 40,000 testing images. For attribute learning, however, we deliberately limited the number of attribute labels used for testing and the number of training images for specific attributes, creating an imbalanced attribute dataset on which to evaluate our proposed methods.
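One way to guarantee that no fashion item overlaps between subsets is to split by item rather than by image. The sketch below assumes every image is tagged with a hypothetical item identifier; the default 220:40:40 ratio mirrors the split sizes above but is applied to items, so image counts are only approximately proportional.

```python
import random
from collections import defaultdict


def item_disjoint_split(samples, ratios=(220, 40, 40), seed=0):
    """Split (image_id, item_id) pairs so that no item spans two subsets."""
    images_by_item = defaultdict(list)
    for image_id, item_id in samples:
        images_by_item[item_id].append(image_id)
    items = sorted(images_by_item)
    random.Random(seed).shuffle(items)  # reproducible shuffle of whole items
    total = sum(ratios)
    n_train = round(len(items) * ratios[0] / total)
    n_val = round(len(items) * ratios[1] / total)
    item_groups = {
        "train": items[:n_train],
        "val": items[n_train:n_train + n_val],
        "test": items[n_train + n_val:],
    }
    # Expand each item back into all of its images.
    return {name: [img for item in group for img in images_by_item[item]]
            for name, group in item_groups.items()}
```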

3. Results and Discussion

In the CFOR system, the object ontology is useful for controlling the training flow, which affects the performance of both object category classification and attribute multitask classification. For object category classification, the ontology controls the amount of training data through its concepts. For attribute multitask classification, the ontology manages local grouping, which directly affects the performance of the proposed local imbalanced data solver on the large-scale dataset.

In this section, we evaluate the effectiveness of different deep networks, supported by the ontology, on both category classification and attribute multitask classification in the CFOR system in order to select the best architecture for training the system. We also compare our results with FashionNet.

Category Classification

We compare the performance of different deep architectures: NASNet, ResNet-18, ResNet-101, FashionNet, and our proposed NASNet with average pooling dropout (NASNet APD) and ResNet with average pooling dropout (ResNet APD). The experiments are evaluated by top-k accuracy (Figure 3). Our goal is to find the best architecture to use as the core network of the CFOR system, so this step can be regarded as preparation before applying the CFOR system to fashion retrieval.

Figure 3. Accuracy plot for top-k accuracy in category classification.
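Top-k accuracy counts a prediction as correct when the true category is among the k highest-scoring classes. A minimal sketch, assuming per-image class scores and integer category labels (the function name is illustrative):

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true class is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
for k in (1, 2, 3):
    print(k, top_k_accuracy(scores, labels, k))
```

By construction the value can only grow with k, which is why gaps between architectures are reported at specific values of k.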

Removing nodes and applying average pooling raises the top-1 (k = 1) category classification accuracy of ResNet-18 APD by 1.23% over the original ResNet-18 architecture. The corresponding gain is 0.93% for ResNet-101 and 0.02% for NASNet v3, each compared with its original architecture. ResNet-101 APD, the best architecture in our experiments, also outperforms FashionNet (the best-performing architecture for category classification on the DeepFashion dataset, compared with others such as WTBI and DARN) by 4.6% at k = 3 and 2.58% at k = 5. Based on these results, ResNet-101 provides better classification performance than the other architectures (NASNet and ResNet-18), so we propose ResNet-101 as the core network architecture for training classification models.

Attribute Learning

Attribute multitask learning is an important part of the CFOR system. In this section, we evaluate the performance of the proposed local imbalanced data solver with MCC in dealing with the imbalanced attribute data on the large-scale fashion dataset.

Precision, the proportion of relevant instances among the retrieved instances, depends on both the true positives and the false positives of each attribute. Because of the imbalanced data problem, these counts are biased, so precision is affected by the imbalance as well. Recall, by contrast, depends on true positives but not on false positives and better reflects performance on attributes with few samples, so we use recall to evaluate the experiments.
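To illustrate the difference, here is a small per-attribute precision and recall computation over binary predictions; the list-of-binary-vectors input format and the function name are assumptions for illustration.

```python
def per_attribute_precision_recall(y_true, y_pred):
    """Return one (precision, recall) pair per attribute.

    y_true, y_pred: lists of binary attribute vectors, one vector per image.
    """
    results = []
    for a in range(len(y_true[0])):
        pairs = [(t[a], p[a]) for t, p in zip(y_true, y_pred)]
        tp = sum(1 for t, p in pairs if t == 1 and p == 1)
        fp = sum(1 for t, p in pairs if t == 0 and p == 1)
        fn = sum(1 for t, p in pairs if t == 1 and p == 0)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0  # false positives never enter
        results.append((precision, recall))
    return results


y_true = [[1, 0], [1, 1], [0, 1], [1, 0]]
y_pred = [[1, 1], [0, 1], [0, 1], [1, 0]]
print(per_attribute_precision_recall(y_true, y_pred))
```

For the second attribute, the extra false positive lowers precision to 2/3 while recall stays at 1.0, which is the property the evaluation relies on.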

Local MTL outperforms STL and MTL on 28 of 35 attributes, with a recall of 54.70% versus 17.06% for STL and 28.70% for MTL. Single-task learning is weak on attributes with few samples, and multitask learning struggles with the severe imbalance and the weak intergroup correlations in fashion data; local MTL reduces these negative effects while strengthening the positive effect of inner-group correlations on attribute learning. As a result, local MTL outperforms STL and MTL on 13 of 15 attributes with fewer samples (Figure 4).

Figure 4. Recall graph of 14 attributes in STL and local MTL.

Apart from the chic, solid, and maxi attributes, on which MTL with and without MCC achieve equal accuracy, MTL with MCC achieves higher recall than MTL without MCC on 20 of the 35 attributes, and overall performance increases by about 3%. For attributes with fewer data, MTL with MCC achieves higher recall than MTL without MCC on 9 of 14 attributes, and overall performance on these attributes increases by 5.14% (see Figure 5 for more details).

Figure 5. Recall graph of 35 attributes using local multitask models with and without MCC.

Retrieval in CFOR System Results

In this experiment, we test the retrieval ability of the CFOR system using mean average precision (MAP) computed over the top 1 (MAP@1) up to the top 30 (MAP@30) retrieval results per query. The experiment checks whether the attributes extracted from the retrieved images match the ground-truth attributes of the query image. Retrieval is based on deep features and 35 attributes belonging to 5 groups. MAP@5 is already acceptable at around 0.531, which shows the effectiveness of the search method. MAP@30 reaches around 0.815 and the curve keeps rising, which indicates that the information prediction methods in the CFOR system are consistent and stable. Figure 6 shows a simple visualization of the retrieval process in the CFOR system.

Figure 6. An example of the retrieval results of the CFOR system.
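A minimal MAP@k sketch, assuming each retrieved image has already been marked relevant (1) or not (0) by comparing its attributes with the query's ground-truth attributes. AP@k is normalized here by the number of relevant results within the top k; other common variants divide by min(k, total relevant), so treat the normalization as an assumption.

```python
def average_precision_at_k(relevance, k):
    """relevance: ranked 0/1 flags for one query's retrieved results."""
    hits, precision_sum = 0, 0.0
    for rank, rel in enumerate(relevance[:k], start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank  # precision at this relevant rank
    return precision_sum / hits if hits else 0.0


def mean_average_precision_at_k(all_relevance, k):
    """Average AP@k over all queries."""
    return sum(average_precision_at_k(r, k) for r in all_relevance) / len(all_relevance)


queries = [[1, 0, 1], [0, 1]]
print(mean_average_precision_at_k(queries, 3))
```

For the first query, precision is 1/1 at rank 1 and 2/3 at rank 3, giving AP@3 = 5/6.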

4. Conclusion

This work presents the coarse-to-fine object retrieval (CFOR) system, a learning framework for online e-commerce retrieval designed to handle large-scale imbalanced datasets. The framework acts on both input and output and restructures datasets from the coarse-grained level to the fine-grained level, which we believe is an effective way to improve learning performance for retrieval. For input reconstruction, the ontology-based framework is used to organize the training flow, to perform local grouping in multitask attribute learning, and to support hierarchical storage and retrieval. For output optimization, we use MCC to minimize the effect of the imbalanced dataset on multitask attribute learning.

Through extensive experiments, we demonstrate the applicability of object ontology in improving the training flow, the effectiveness of different deep networks (ResNet and NASNet) on key tasks in fine-grained retrieval, and the usefulness of local multitask attribute learning and an MCC-based imbalanced data solver in attribute multitask learning. The CFOR system is designed to be flexible so that it can easily be optimized in the future.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 7)

To provide fine-grained information to the CFOR system, attribute learning is a key task, and it should be optimized both for processing time and for its ability to handle large-scale imbalanced datasets.

1. Framework

Local multitask learning is applied to attribute learning. The proposed framework, shown in Figures 1 and 2, has an offline and an online phase and consists of three main components. The first component is a local multitask transfer learning model whose loss function is designed to exploit inner-group correlations among attributes. The second component is an imbalanced data resolver based on the Matthews correlation coefficient (MCC) that requires no adjustment to the pretrained model or the loss function. The third component covers the prior knowledge used for local attribute grouping, which enables local multitask learning.

Figure 1. Local MTL with an imbalanced data problem solver framework (Offline phase).
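The post describes the resolver only as MCC-based and as requiring no change to the pretrained model or loss function. One approach consistent with that description, assumed here rather than confirmed, is to tune a per-attribute decision threshold on validation scores so that MCC is maximized. A minimal sketch with hypothetical function names:

```python
import math


def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient; returns 0.0 when it is undefined."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


def best_threshold_by_mcc(scores, labels):
    """Return (threshold, mcc) maximizing MCC for one attribute's scores.

    Assumption: candidate thresholds are the observed sigmoid scores.
    """
    best_t, best_m = 0.5, float("-inf")
    for t in sorted(set(scores)):
        preds = [s >= t for s in scores]
        tp = sum(1 for p, y in zip(preds, labels) if p and y == 1)
        tn = sum(1 for p, y in zip(preds, labels) if not p and y == 0)
        fp = sum(1 for p, y in zip(preds, labels) if p and y == 0)
        fn = sum(1 for p, y in zip(preds, labels) if not p and y == 1)
        m = mcc(tp, tn, fp, fn)
        if m > best_m:
            best_t, best_m = t, m
    return best_t, best_m


print(best_threshold_by_mcc([0.9, 0.8, 0.3, 0.2, 0.1], [1, 1, 0, 0, 0]))
```

MCC uses all four confusion-matrix cells, so a threshold that simply labels everything with the majority class scores 0 rather than looking deceptively good, which is why it suits rare attributes.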

The input and output of the learning framework are images and their attribute vectors, respectively. With local grouping, however, the size of each attribute vector equals the number of attributes in its group, and the dataset is merged or split according to the grouping.

To evaluate the effectiveness of the proposed framework, we apply it to the fashion domain and split the dataset into five local groups: fabric, shape, part, style, and texture. Because intergroup correlations are weak in fashion, the shared block should be designed to exploit inner-group correlations to improve overall performance. For crowd attributes (such as activities, locations, and participants), by contrast, intergroup correlations should be taken into account, so the shared block should be adapted to the context.

Figure 2. Local MTL with an imbalanced data problem solver framework (Online phase).
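Splitting and merging by local group can be sketched as a simple re-indexing of the attribute vector. The five group names come from the text; the 7-per-group attribute names are hypothetical placeholders, since the full list of 35 attributes is not given.

```python
GROUP_NAMES = ["fabric", "shape", "part", "style", "texture"]
# Hypothetical attribute names: 7 per group, 35 in total.
ATTRIBUTES = [f"{g}_{j}" for g in GROUP_NAMES for j in range(7)]
GROUPS = {g: [a for a in ATTRIBUTES if a.startswith(g + "_")] for g in GROUP_NAMES}


def split_by_group(vector, attribute_names, groups):
    """Turn one global attribute vector into one smaller vector per local group."""
    position = {name: i for i, name in enumerate(attribute_names)}
    return {g: [vector[position[a]] for a in attrs] for g, attrs in groups.items()}


def merge_groups(group_vectors, attribute_names, groups):
    """Inverse of split_by_group: rebuild the global attribute vector."""
    value = {}
    for g, attrs in groups.items():
        value.update(zip(attrs, group_vectors[g]))
    return [value[a] for a in attribute_names]
```

Each local multitask model then only sees the label vector of its own group, so adding attributes to one group does not force retraining of the others.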

2. Deep Multitask Learning

Our aim is to estimate a number of fashion attributes with a joint estimation model. However, when the set of attributes is dynamic, a single joint MTL model becomes difficult to train and impractical to use as the number of attributes grows. The local grouping method helps solve this problem.

Framework in Detail

In the experiments, the proposed framework processes the query image and generates a confidence score vector containing 7 attribute scores for each of the 5 groups (35 scores in total), which is then thresholded to obtain binary outputs. The architecture is described in detail below.

Figure 1 illustrates the overall structure of the proposed method. For each group, we assume a training set of $N$ fashion images, each with $M$ attributes. The dataset is represented as $D = \{(X_i, Y_i)\}_{i=1}^{N}$, where $X_i$ is the $i$-th image and $Y_i \in \{0, 1\}^{M}$ is its label vector, encoded as a binary (multi-hot) vector with one entry per attribute. Inspired by prior research, we employ an end-to-end DNN architecture as a shared block to learn joint representations for all tasks. The loss function is binary cross-entropy and the activation function at the output layer is sigmoid, chosen for their simplicity and the flexibility they allow in modifying the DNN architecture.
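The sigmoid-plus-binary-cross-entropy output can be sketched for a single image of one group as follows. It uses the standard numerically stable form computed from raw logits; averaging (rather than summing) over the $M$ attributes and the 0.5 threshold are assumptions, and the function names are illustrative.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over M attributes, computed from raw logits.

    Uses max(z, 0) - z*y + log(1 + exp(-|z|)), which equals
    -[y*log(sigmoid(z)) + (1-y)*log(1 - sigmoid(z))] but avoids overflow.
    """
    losses = [max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
              for z, y in zip(logits, targets)]
    return sum(losses) / len(losses)


def predict(logits, threshold=0.5):
    """Binary attribute decisions from sigmoid scores."""
    return [int(sigmoid(z) >= threshold) for z in logits]
```

Because each attribute gets its own sigmoid, several attributes can be active at once, which is what distinguishes this multi-label setup from a softmax over mutually exclusive classes.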

Network Architecture

NASNet generates network architectures automatically: it searches for good building blocks on a small dataset and then scales them up to a larger one. Specifically, the best cells are found on the CIFAR-10 dataset and then transferred to the ImageNet dataset by stacking multiple copies of them, each with its own parameters. The resulting model improves top-1 accuracy by 1.2% over the best human-designed architectures. NASNet thus outperforms previous architectures and provides a model pretrained on the diverse ImageNet dataset that is well suited to transfer learning. Starting from this pretrained NASNet model, we apply transfer learning to the DeepFashion dataset to speed up convergence and improve performance, and we add a dropout layer to mitigate overfitting. Running the NASNet generation algorithm to tailor a model to DeepFashion would be a promising approach, but generating and training a model from scratch requires significant time and hardware resources. Due to hardware limitations, we use only transfer learning.

Figure 3. Best normal and reduction cells identified on CIFAR-10; the ImageNet architecture (right) is built from these best convolutional cells. Zoph et al. built two types of cells so that architectures can be created for images of any size: normal cells return a feature map of the same dimension, while reduction cells return a feature map whose height and width are reduced by a factor of two.
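The effect of stacking the two cell types on spatial resolution can be traced with a few lines of code. This sketch follows the repeated pattern of normal cells separated by reduction cells; it ignores stem layers and channel growth, and rounding up after halving assumes stride-2 operations with "same" padding. Function names are hypothetical.

```python
import math


def nasnet_style_stack(n_normal, n_reductions):
    """n_normal normal cells, then a reduction cell, repeated; ends with normal cells."""
    stack = []
    for _ in range(n_reductions):
        stack += ["N"] * n_normal + ["R"]
    return stack + ["N"] * n_normal


def trace_spatial_size(height, width, cells):
    """Track H x W through the stack: 'N' keeps it, 'R' halves it (rounded up)."""
    sizes = []
    for cell in cells:
        if cell == "R":
            height, width = math.ceil(height / 2), math.ceil(width / 2)
        elif cell != "N":
            raise ValueError(f"unknown cell type: {cell!r}")
        sizes.append((cell, height, width))
    return sizes


stack = nasnet_style_stack(2, 2)
print(stack)
print(trace_spatial_size(32, 32, stack)[-1])
```

Since only the reduction cells change the resolution, the same two learned cells can be reused for inputs of any size by changing how many reductions are stacked.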

We experiment with NASNet architectures to determine which one suits each specific task in the CFOR system. In our fashion retrieval experiments, the category classifier and the region classifier are trained by transfer learning with single-task learning, while fashion attribute recognition uses local multitask learning. In addition, to adapt to large-scale datasets and reduce overfitting, we recommend replacing the final fully connected layer with a global average pooling layer combined with dropout.
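A minimal sketch of such a head: global average pooling collapses each channel of the last feature map to one value, inverted dropout is applied during training, and a small linear classifier produces the outputs. Keeping a linear classifier after pooling, the dropout rate, and all names are assumptions for illustration; the post does not give the exact layer configuration.

```python
import random


def global_average_pool(feature_map):
    """feature_map: C channels, each an H x W nested list -> one value per channel."""
    return [sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
            for channel in feature_map]


def inverted_dropout(values, p, training, rng):
    """Zero each value with probability p during training and rescale the rest."""
    if not training or p == 0.0:
        return list(values)
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def linear(values, weights, bias):
    """Dense layer: one output per row of `weights`."""
    return [sum(w * v for w, v in zip(row, values)) + b
            for row, b in zip(weights, bias)]


def apd_head(feature_map, weights, bias, p=0.5, training=False, rng=None):
    """Hypothetical head layout: GAP -> dropout -> linear classifier."""
    rng = rng or random.Random(0)
    pooled = global_average_pool(feature_map)
    return linear(inverted_dropout(pooled, p, training, rng), weights, bias)
```

Pooling has no parameters, so the head's weight count depends only on the number of channels and classes rather than on the spatial size of the feature map, which is what helps against overfitting.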


In the next post, we will discuss the searching and indexing method in the CFOR system and evaluate the effectiveness of the whole system.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 6)

In this section, we discuss ontology, fashion ontology, and related information, and present the contributions of object ontology to the CFOR system.

1. Ontology Definition for CFOR System

As described by Guarino, ontology is a "formal, explicit specification of a shared conceptualization." Typically, ontologies consist of concepts and their hierarchical structure, aiding in organizing information within a domain. A complete ontology typically includes concepts, relations, and axioms. Additionally, ontologies offer several key advantages:

  • Describing domain knowledge through a semantic hierarchical tree, with concepts represented as nodes identified by words or phrases.
  • Bridging the semantic gap in various tasks, including those in computer vision and other disciplines.
  • Enhancing software engineering practices by improving flexibility, reliability, specification, and reusability.
  • Supporting multitask problem-solving capabilities.

Any proposed ontology should satisfy two fundamental criteria:

  • Wide recognition within the community.
  • Feasibility for formalization using mathematical expressions, enabling digitization.

In our approach, we employ ontological engineering to facilitate communication and information sharing across different levels of data abstraction involved in image fashion retrieval, detection, and information tagging.

The object ontology comprises two primary levels: coarse-grained and fine-grained.

  • At the coarse level, the object ontology includes regions, categories, or high-level conceptual entities, which leverage global features extracted by deep networks for similarity retrieval. However, these deep features are treated as black boxes, lacking explicit semantic information to aid users in their search process.
  • At the fine-grained level, the object ontology encompasses attributes that provide detailed descriptions of objects.

In our experiment, we focus on describing the object "Fashion." The fashion ontology is constructed using prior knowledge and information from the DeepFashion dataset, along with ontology definitions introduced by Guarino. See Figure 1 for an illustration of the fashion ontology.

Figure 1. Fashion ontology in general and a version of ontology for clothes.

The fashion ontology developed comprises three primary semantic levels:

  1. Regions: Representing areas such as Top, Bottom, and Body.
  2. Categories: Specific objects associated with each region, such as dresses or robes for the Body region.
  3. Attributes: Describing detailed visual concepts like denim or fur.

To streamline the discussion, our investigation focuses on the object fashion across three regions (Top, Body, and Bottom), select categories within these regions, and their respective attributes.

Within the CFOR system, a query image undergoes processing starting from the coarse level of the object ontology to identify the region and category of the corresponding object. Subsequently, the object proceeds to the fine-grained concept ontology to ascertain attributes. Once all necessary information is obtained, the object undergoes indexing and similarity distance computation to identify similar images in the database, ranked by a cumulative score derived from similarity scores of global features and attribute learning between the query image and target database images. For a detailed illustration, refer to Figure 2.

Figure 2. An example of a relationship between the query image and semantic information from the coarse-grained level to the fine-grained level of the fashion ontology.
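The final ranking step can be sketched under loudly stated assumptions: global features are compared with cosine similarity, attributes with the fraction of matching binary labels, and the two are combined linearly with a hypothetical weight `alpha`. The post only states that the score is cumulative over global-feature and attribute similarity, so this particular combination is illustrative.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def attribute_agreement(query_attrs, target_attrs):
    """Fraction of binary attributes on which query and target agree."""
    same = sum(1 for q, t in zip(query_attrs, target_attrs) if q == t)
    return same / len(query_attrs)


def rank_by_cumulative_score(query, candidates, alpha=0.5):
    """query: (features, attrs); candidates: (image_id, features, attrs) triples.

    Assumption: score = alpha * feature similarity + (1 - alpha) * attribute agreement.
    """
    q_feat, q_attrs = query
    scored = [(image_id,
               alpha * cosine_similarity(q_feat, feat)
               + (1 - alpha) * attribute_agreement(q_attrs, attrs))
              for image_id, feat, attrs in candidates]
    return sorted(scored, key=lambda s: s[1], reverse=True)


query = ([1.0, 0.0], [1, 0, 1])
candidates = [("y", [0.0, 1.0], [1, 0, 1]),
              ("x", [1.0, 0.0], [1, 0, 1]),
              ("z", [1.0, 0.0], [0, 1, 0])]
print(rank_by_cumulative_score(query, candidates))
```

Candidate `x` matches on both signals and ranks first; `y` and `z` each match on only one signal, which is how the attribute level can separate images whose deep features look alike.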

2. Fashion Object Ontology

In this section, we introduce the fashion object ontology. Within the fashion domain, we categorize semantic fashion concepts based on regions. Each region encompasses a detailed ontology comprising categories and attributes. To facilitate experiments using the DeepFashion dataset, we extend the fashion ontology within the "Clothes" branch (refer to Figure 9). It's essential to emphasize that the proposed ontology is not specific to any application and should be viewed as a flexible foundation.

The fashion object ontology consists of multiple levels of concepts, with relations between each level to articulate their associations. Two primary relations are employed:

  1. "Part of": This relation specifies that the concepts are components of the main concept.
  2. "Has a": This relation describes the main concept in detail.

For this study, we concentrate solely on the Clothes branch to ensure equitable comparisons with other methodologies. The Clothes taxonomy comprises 50 distinct categories. A clothing region taxonomy has been established (refer to Figure 3), organizing all clothing categories hierarchically. The first level of this hierarchy represents the most general clothing region, with three primary regions defined:

  1. Top (e.g., tee and tank)
  2. Bottom (e.g., skirt and jeans)
  3. Body (e.g., dress and robe)

Figure 3. Excerpt from the “Clothes” taxonomy defined in the fashion ontology.
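The region-to-category hierarchy can be held in a plain dictionary. This sketch lists only the example categories named above (the full branch has 50), and the helper names are hypothetical.

```python
# Illustrative excerpt of the Clothes taxonomy: only the example categories
# named in the text are listed; the full branch has 50 categories.
CLOTHES_TAXONOMY = {
    "Top": ["tee", "tank"],
    "Bottom": ["skirt", "jeans"],
    "Body": ["dress", "robe"],
}

# Reverse index: category -> region.
REGION_OF = {cat: region for region, cats in CLOTHES_TAXONOMY.items() for cat in cats}


def ontology_path(category):
    """Coarse-to-fine path from the branch root down to a category."""
    if category not in REGION_OF:
        raise KeyError(f"category not in taxonomy excerpt: {category!r}")
    return ["Clothes", REGION_OF[category], category]


print(ontology_path("jeans"))
```

The reverse index lets a predicted category be checked against the predicted region, so the coarse levels of the ontology stay consistent with each other.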

3. Fine-Grained Object Ontology

Fine-grained object ontology is used to describe objects at the attribute level. Semantic information such as attributes can help customers retrieve items (see Figure 4). It is important to note that the proposed ontology is not application dependent and should be considered as an extensible basis.

Figure 4. Fine-grained group at the attribute level.

Cloth attributes vary across different levels: some attributes, like color, are common across all cloth regions, while others are specific to certain regions or categories. Our ontology is structured into two main parts, each detailed in the following sections:

  1. Specific fashion concepts: pertaining to particular characteristics of clothes such as fabric, part, and style.
  2. Visual concepts: related to popular visual characteristics like color, shape, and texture, not exclusive to fashion.

Rudd et al. demonstrated in a study that a multitask learning-based model outperforms a combination of single-task learning-based models in face attribute prediction. While this approach shows promising results for fashion attributes as well, there's a significant difference in the quantity of attributes between faces and fashion items. This disparity can pose challenges in scaling the system, such as in training and storage requirements. To address this, we propose applying local multitask learning to attribute learning, providing more flexibility. Further explanation is provided in the subsequent sections.


In the next post, we will discuss attribute learning and its relation to multitask learning.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 3)

In the previous article, we described how the contrastive divergence loss function is integrated into federated learning and how it can improve model performance and address non-IID data in autonomous driving. In this follow-up, we describe the federated learning configurations we use and present experimental results that support the efficacy of our approach.

1. Implementation

Dataset: Our experimentation involves three datasets (see Table 1): Udacity+, Gazebo Indoor, and Carla Outdoor. Gazebo and Carla datasets exhibit non-IID characteristics, while Udacity+ represents a non-IID variant of the Udacity dataset.

| Dataset | Total samples | Average samples in each silo (Gaia) | Average samples in each silo (NWS) | Average samples in each silo (Exodus) |
|---|---|---|---|---|

Table 1: The Statistic of Datasets in Our Experiments.

Network Topology: Our experimentation encompasses three distinct federated topologies, namely the Internet Topology Zoo (Gaia), North American data centers (NWS), and the Zoo Exodus network (Exodus). The primary focus is on the Gaia topology, while supplementary insights from the NWS and Exodus topologies are provided in our ablation study.

Training: Within each silo, model training is executed with a batch size of 32 and a learning rate set at 0.001, facilitated by the Adam optimizer. The local training regimen within each silo precedes the transmission and aggregation of models using the specified global aggregation equation. The training regimen spans 3,600 communication rounds and leverages a simulation environment akin to that described in Nguyen et al. (2022), powered by an NVIDIA 1080 GPU.
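As a stand-in for the global aggregation equation referred to above, here is a FedAvg-style weighted average of flat parameter lists, together with a decentralized variant in which a silo averages its own model with its neighbors'. Weighting by local sample count and all function names are assumptions for illustration, not the exact equation used in the experiments.

```python
def weighted_average(models, weights):
    """Average flat parameter lists, weighting each model (e.g., by sample count)."""
    total = sum(weights)
    return [sum(w * m[i] for m, w in zip(models, weights)) / total
            for i in range(len(models[0]))]


def aggregate_with_neighbors(silo, params, neighbors, n_samples):
    """DFL-style step: a silo averages its own model with its neighbors' models."""
    group = [silo] + list(neighbors[silo])
    return weighted_average([params[s] for s in group],
                            [n_samples[s] for s in group])


params = {"a": [0.0], "b": [2.0], "c": [10.0]}
neighbors = {"a": ["b"]}
n_samples = {"a": 1, "b": 1, "c": 1}
print(aggregate_with_neighbors("a", params, neighbors, n_samples))
```

In the decentralized setting, silo `a` only sees its neighbor `b`, so the distant silo `c` does not influence this round; how quickly information spreads therefore depends on the topology.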

Baselines: Our comparative analysis involves several contemporary methods across diverse learning scenarios, including Random and Constant baselines as outlined by Loquercio et al. (2018). Within the Centralized Local Learning (CLL) scenario, we utilize Inception-V3, MobileNet-V2, VGG-16, and Dronet as baseline models. In the context of Server-based Federated Learning (SFL), our comparison extends to FedAvg, FedProx, and STAR. For the Decentralized Federated Learning (DFL) setting, our evaluation includes MATCHA, MBST, and FADNet. Model effectiveness is assessed using Root Mean Square Error (RMSE) and Mean Absolute Error (MAE), while wall-clock time (ms) serves as a metric for training duration.
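Both error metrics are straightforward to compute over predicted and ground-truth steering angles. A minimal sketch (function names illustrative):

```python
import math


def rmse(y_true, y_pred):
    """Root mean square error, e.g. over predicted steering angles."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


print(rmse([0.0, 0.0], [3.0, 4.0]), mae([0.0, 0.0], [3.0, 4.0]))
```

Because RMSE squares the errors, it penalizes occasional large steering mistakes more heavily than MAE, so reporting both gives a fuller picture.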

2. Qualitative Results

In practice, we have noticed that the early phases of federated learning often produce poor accumulated models. Unlike other approaches that tackle the non-IID issue by refining the accumulation step whenever silos transmit their models, we directly mitigate the impact of divergence factors during the local learning phase of each silo. Our method aims to minimize the discrepancy between the distribution of the accumulated weights from neighboring silos in the backbone network (representing divergence factors) and the weights specific to silo $i$ in the sub-network (comprising locally learned knowledge). Once the distributions across silos are sufficiently synchronized, we reduce the influence of the sub-network and prioritize the steering angle prediction task. Inspired by the contrastive loss of the original Siamese Network, our proposed Contrastive Divergence Loss is formulated as follows:
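As a point of reference, and not as CDL itself, here is a minimal sketch of the standard margin-based contrastive loss used with Siamese networks, which CDL takes as its inspiration. The convention that `similar=True` marks a matching pair and the default margin are assumptions; CDL's own comparison of backbone and sub-network weight distributions is not modeled here.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def contrastive_loss(emb_a, emb_b, similar, margin=1.0):
    """Margin-based contrastive loss for one pair of embeddings.

    Similar pairs are pulled together (0.5 * d^2); dissimilar pairs are pushed
    apart until they are at least `margin` away (0.5 * max(0, margin - d)^2).
    """
    d = euclidean(emb_a, emb_b)
    if similar:
        return 0.5 * d ** 2
    return 0.5 * max(0.0, margin - d) ** 2


print(contrastive_loss([0.0, 0.0], [3.0, 4.0], similar=True))
```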

| Model | Main Focus | Learning Method | RMSE (Udacity+) | RMSE (Gazebo) | RMSE (Carla) | MAE (Udacity+) | MAE (Gazebo) | MAE (Carla) | # Training Parameters | Avg. Cycle Time (ms) |
|---|---|---|---|---|---|---|---|---|---|---|
| Inception | Architecture Design | CLL | 0.209 | 0.085 | 0.297 | 0.197 | 0.062 | 0.207 | 21,787,617 | - |
| MobileNet | Architecture Design | CLL | 0.193 | 0.083 | 0.286 | 0.176 | 0.057 | 0.200 | 2,225,153 | - |
| VGG-16 | Architecture Design | CLL | 0.190 | 0.083 | 0.316 | 0.161 | 0.050 | 0.184 | 7,501,587 | - |
| DroNet | Architecture Design | CLL | 0.183 | 0.082 | 0.333 | 0.150 | 0.053 | 0.218 | 314,657 | - |
| FedAvg | Aggregation Optimization | SFL | 0.212 | 0.094 | 0.269 | 0.185 | 0.064 | 0.222 | 314,657 | 152.4 |
| FedProx | Aggregation Optimization | SFL | 0.152 | 0.077 | 0.226 | 0.118 | 0.063 | 0.151 | 314,657 | 111.5 |
| STAR | Aggregation Optimization | SFL | 0.179 | 0.062 | 0.208 | 0.149 | 0.053 | 0.155 | 314,657 | 299.9 |
| MATCHA | Topology Design | DFL | 0.182 | 0.069 | 0.208 | 0.148 | 0.058 | 0.215 | 314,657 | 171.3 |
| MBST | Topology Design | DFL | 0.183 | 0.072 | 0.214 | 0.149 | 0.058 | 0.206 | 314,657 | 82.1 |
| FADNet | Topology Design | DFL | 0.162 | 0.069 | 0.203 | 0.134 | 0.055 | 0.197 | 317,729 | 62.6 |
| CDL (ours) | Loss Optimization | CLL | 0.169 | 0.074 | 0.266 | 0.149 | 0.053 | 0.172 | 629,314 | - |
| CDL (ours) | Loss Optimization | SFL | 0.150 | 0.060 | 0.208 | 0.104 | 0.052 | 0.150 | 629,314 | 102.2 |
| CDL (ours) | Loss Optimization | DFL | 0.141 | 0.062 | 0.183 | 0.083 | 0.052 | 0.147 | 629,314 | 72.7 |

Table 2: Performance comparison between different methods. The Gaia topology is used.

Table 2 summarizes the performance of our proposed method against recent state-of-the-art approaches. Our CDL under the Siamese setup with two ResNet-8 models outperforms the other methods: it achieves the lowest RMSE on all three datasets (Udacity+, Gazebo, and Carla) and the lowest MAE on Udacity+ and Carla, while on Gazebo its MAE (0.052) is close to the best value (0.050, VGG-16). Although CDL does not change the backbone network itself, the additional sub-network required by the Siamese setup doubles the number of parameters during training (629,314 versus 314,657). Our CDL with ResNet-8 performs best relative to the other baselines in the DFL scenario, and to a lesser extent in the SFL and CLL setups.

3. Contrastive Divergence Loss Analysis

CDL Performance Across Various Topologies

In practice, training federated algorithms becomes more complex as the topology involves more vehicle data silos. To assess the efficacy of our CDL, we conduct training experiments and compare the results with other baseline methods across different topologies. Table 3 presents the performance comparison of DroNet, FADNet, and our CDL with a ResNet-8 backbone when trained using DFL across three distributed network infrastructures with varying numbers of silos: Gaia (11 silos), NWS (22 silos), and Exodus (79 silos). The table clearly demonstrates that our CDL consistently achieves superior results across all topology configurations. In contrast, DroNet encounters divergence issues, and FADNet exhibits suboptimal performance, particularly in the Exodus topology with 79 silos.

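| Topology | Method | RMSE (Udacity+) | RMSE (Gazebo) | RMSE (Carla) |
|---|---|---|---|---|
| Gaia (11 silos) | DroNet | 0.177 (↓0.036) | 0.073 (↓0.011) | 0.244 (↓0.061) |
| | FADNet | 0.162 (↓0.021) | 0.069 (↓0.007) | 0.203 (↓0.020) |
| | CDL (ours) | 0.141 | 0.062 | 0.183 |
| NWS (22 silos) | DroNet | 0.183 (↓0.045) | 0.075 (↓0.017) | 0.239 (↓0.057) |
| | FADNet | 0.165 (↓0.027) | 0.070 (↓0.012) | 0.200 (↓0.018) |
| | CDL (ours) | 0.138 | 0.058 | 0.182 |
| Exodus (79 silos) | DroNet | 0.448 (↓0.310) | 0.208 (↓0.147) | 0.556 (↓0.380) |
| | FADNet | 0.179 (↓0.041) | 0.081 (↓0.020) | 0.238 (↓0.062) |
| | CDL (ours) | 0.138 | 0.061 | 0.176 |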

Table 3: Performance under different topologies.

CDL with Various Architectures

As a loss function, CDL can be combined with different network architectures in the Siamese setup and improves their performance. Figure 1 shows the efficacy of CDL with DroNet, FADNet, Inception, MobileNet, and VGG-16 on the Gaia topology in the DFL scenario. The results show that CDL mitigates the non-IID challenge across these architectures, consistently improving performance.

Figure 1. Performance of CDL under different networks in Siamese setup.

CDL with IID Data

Figure 2 illustrates CDL's efficacy under different data distributions. Although CDL is designed primarily for the non-IID challenge, it also yields marginal performance gains for models trained on IID data. Because of the Siamese setup, CDL behaves similarly to triplet loss; given that triplet loss is known to work well with IID data, this suggests that CDL can also improve model performance when the data are IID.

Figure 2. Performance of different methods on IID dataset (Udacity) and non-IID dataset (Udacity+).

4. Ablation Study

Figure 3 shows the training RMSE of our two baselines, DroNet and FADNet, and of our proposed CDL. The results show the convergence behavior of these methods on three datasets (Udacity+, Gazebo, Carla) under the NWS and Gaia topologies. Our proposed CDL reaches a better convergence point than the two baselines: while DroNet and FADNet struggle to converge or show poor convergence trends, CDL is better at escaping local optima and is less biased toward any specific silo.

Figure 3. The convergence ability of different methods under Gaia topology (top row) and NWS topology (bottom row).

5. Discussion

Although our method exhibits promising results, there are several areas for potential improvement that warrant consideration for future work:

  • Our CDL is specifically designed to address the non-IID problem, meaning its effectiveness heavily depends on the presence of non-IID characteristics within the autonomous driving data. Thus, in scenarios where data distributions across vehicles are relatively consistent or lack significant non-IID factors, the proposed contrastive divergence loss may only lead to limited performance enhancements.

  • While our proposal has been validated on autonomous driving datasets, real-world testing on actual vehicles has not yet been conducted. Conducting a study involving various driving scenarios, such as interactions with pedestrians, cyclists, and other vehicles, could provide further validation of our method's efficacy.

  • Although our approach is tailored for autonomous driving applications, its underlying principles may have applicability in other domains where non-IID data is prevalent. Exploring its effectiveness in areas like healthcare, IoT, and industrial settings could expand its potential impact.


We presented a new method to address the non-IID problem in federated autonomous driving using contrastive divergence loss. Our method directly reduces the effect of divergence factors in the learning process of each silo. The experiments on three benchmarking datasets demonstrate that our proposed method performs substantially better than current state-of-the-art approaches. In the future, we plan to test our strategy with more data silos and deploy the trained model using an autonomous vehicle on roads.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 2)

In the previous post, we explored federated learning for autonomous driving, its potential, and the challenges posed by non-IID data distributions. We also discussed how contrastive divergence offers a promising way to tackle the non-IID problem in federated learning. Building on that foundation, this post describes how we integrate a contrastive divergence loss into the federated learning framework: how the loss is incorporated into the federated training process, and what this means for model performance on non-IID autonomous driving data.

1. Overview

Motivation: The effectiveness of federated learning algorithms in autonomous driving hinges on two critical factors: firstly, the ability of each local silo to glean meaningful insights from its own data, and secondly, the synchronization among neighboring silos to mitigate the impact of the non-IID problem. Recent efforts have primarily focused on addressing these challenges through various means, including optimizing accumulation processes and optimizers, proposing novel network topologies, or leveraging robust deep networks capable of handling the distributed nature of the data. However, as highlighted by Duan et al., the indiscriminate adoption of high-performance deep architectures and their associated optimizations in centralized local learning scenarios can lead to increased weight variance among local silos during the accumulation process in federated setups. This variance detrimentally impacts model convergence and may even induce divergence, underscoring the need for nuanced approaches to ensure the efficacy of federated learning in autonomous driving contexts.

Figure 1. The Siamese setup when our CDL is applied to train a federated autonomous driving model. ResNet-8 is used for both the backbone and the sub-network in the Siamese setup. During inference, the sub-network is removed. Dotted lines represent the backward pass. Our CDL has two components: the positive contrastive divergence loss $\mathcal{L}_{\rm cd^+}$ and the negative regularizer term $\mathcal{L}_{\rm cd^-}$. The local regression loss $\mathcal{L}_{\rm lr}$ for automatic steering prediction is calculated only from the backbone network.

Siamese Network Approach: We propose to tackle the non-IID problem directly within each local silo by handling two challenges separately: learning good local features and achieving synchronization across silos. To this end, we implement two distinct networks within each silo: one extracts meaningful features from local image data, while the other minimizes the distribution gap between the current model weights and those of neighboring silos. We use a Siamese Network architecture with two branches. The first branch, the backbone network, learns local image features for autonomous steering using a local regression loss $\mathcal{L}_{\rm lr}$, while also incorporating a positive contrastive divergence loss $\mathcal{L}_{\rm cd^+}$ to assimilate knowledge from neighboring silos. The second branch, the sub-network, regulates the divergence factors arising from the backbone's knowledge through a contrastive regularizer term $\mathcal{L}_{\rm cd^-}$. See Figure 1 for more details.

In practice, the sub-network is initialized with the same weights as the backbone in the first communication round. From the next communication round onward, once the backbone has been updated by the accumulation step, each silo's local model is trained with the contrastive divergence loss. The sub-network produces auxiliary features with the same dimensions as the backbone's output features. During training, we expect the contrastive divergence loss to keep the weight discrepancy between the backbone and the sub-network small. The weights of all silos become synchronized when the gradients of the backbone and sub-network learning processes differ only minimally. For reference, the standard local update at silo $i$ is:

$$\theta_i(k+1) = \theta_i(k) - \alpha_k \frac{1}{m}\sum_{h=1}^{m} \nabla \mathcal{L}_{\rm lr}\big(\theta_i(k), \xi_i^h(k)\big)$$

where $\mathcal{L}_{\rm lr}$ is the local regression loss for autonomous steering.
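The local update above averages the per-sample gradients of the $m$ samples in a minibatch and takes one step of size $\alpha_k$. A minimal NumPy sketch, assuming a hypothetical linear steering predictor with an absolute-error loss as a stand-in for $\mathcal{L}_{\rm lr}$ (the post uses a deep network); `local_sgd_step` and `abs_error_grad` are illustrative names:

```python
import numpy as np

def local_sgd_step(theta, grad_fn, batch, lr):
    """theta(k+1) = theta(k) - lr * (1/m) * sum_h grad L(theta(k), xi^h)."""
    grads = np.stack([grad_fn(theta, sample) for sample in batch])  # one gradient per sample
    return theta - lr * grads.mean(axis=0)  # average over the m samples, then step

def abs_error_grad(theta, sample):
    """Hypothetical stand-in for grad L_lr: linear predictor x @ theta with
    absolute error |x @ theta - y|, whose subgradient is sign(.) * x."""
    x, y = sample
    return np.sign(x @ theta - y) * x

# Toy minibatch of m = 2 (features, steering angle) pairs.
batch = [(np.array([1.0, 0.0]), 1.0), (np.array([0.0, 1.0]), -1.0)]
theta_next = local_sgd_step(np.zeros(2), abs_error_grad, batch, lr=0.1)
print(theta_next)
```

In a real silo, `grad_fn` would be replaced by backpropagation through the steering network; only the averaging-then-stepping structure carries over.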

2. Contrastive Divergence Loss

In practice, we have noticed that the initial phases of federated learning often yield poor accumulated models. Unlike approaches that tackle the non-IID issue by refining the accumulation step whenever silos transmit their models, we directly mitigate the impact of divergence factors during the local learning phase of each silo. Our method minimizes the discrepancy between the distribution of the weights accumulated from neighboring silos in the backbone network (which carry the divergence factors) and the weights specific to silo $i$ in the sub-network (which carry the locally learned knowledge). Once the silos are sufficiently synchronized, we reduce the influence of the sub-network and prioritize the steering angle prediction task. Inspired by the contrastive loss of the original Siamese Network, we formulate the Contrastive Divergence Loss as follows:

$$\mathcal{L}_{\rm cd} = \beta \mathcal{L}_{\rm cd^+} + (1-\beta) \mathcal{L}_{\rm cd^-} = \beta \mathcal{H}(\theta^b_i, \theta^s_i) + (1-\beta) \mathcal{H}(\theta^s_i, \theta^b_i)$$

where $\mathcal{L}_{\rm cd^+}$ is the positive contrastive divergence term, $\mathcal{L}_{\rm cd^-}$ is the negative regularizer term, and $\mathcal{H}$ is the Kullback-Leibler divergence loss function:

$$\mathcal{H}(\hat{y}, y) = \sum \mathbf{f}(\hat{y}) \log\left(\frac{\mathbf{f}(\hat{y})}{\mathbf{f}(y)}\right)$$

where $\hat{y}$ is the predicted representation and $y$ is the dynamic soft label.
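The loss above can be sketched in NumPy. This is a minimal illustration under two assumptions the post does not fix: $\mathcal{H}$ is evaluated on the output features of the backbone and the sub-network (the equation writes it in terms of $\theta^b_i$ and $\theta^s_i$), and $\mathbf{f}$ is a softmax. The names `kl_h` and `cd_loss` are hypothetical.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax along the last axis (our assumed choice of f)."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_h(y_hat, y):
    """H(y_hat, y) = sum f(y_hat) * log(f(y_hat) / f(y)); we average it over the batch."""
    lp, lq = log_softmax(y_hat), log_softmax(y)
    return float((np.exp(lp) * (lp - lq)).sum(axis=-1).mean())

def cd_loss(feat_b, feat_s, beta):
    """L_cd = beta * H(backbone, sub) + (1 - beta) * H(sub, backbone)."""
    l_pos = kl_h(feat_b, feat_s)  # positive term L_cd+
    l_neg = kl_h(feat_s, feat_b)  # negative regularizer L_cd-
    return beta * l_pos + (1.0 - beta) * l_neg

rng = np.random.default_rng(0)
feat_b = rng.standard_normal((4, 16))  # toy backbone features (batch of 4)
feat_s = rng.standard_normal((4, 16))  # toy sub-network features
print(cd_loss(feat_b, feat_s, beta=0.5))
```

Because KL divergence is asymmetric, the two terms differ in general, and the loss is zero when both branches produce identical features. In actual training, the update rules given later in this post route $\mathcal{L}_{\rm cd^+}$ to the backbone and $\mathcal{L}_{\rm cd^-}$ to the sub-network.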

Treating $\mathcal{L}_{\rm cd^+}$ as a Bayesian statistical inference task, our goal is to estimate the model parameters $\theta^{b*}$ by minimizing the Kullback-Leibler divergence $\mathcal{H}(\theta^b_i, \theta^s_i)$ between the measured regression probability distribution of the observed local silo $P_0(x|\theta^s_i)$ and the accumulated model $P(x|\theta^b_i)$. We assume that the model distribution has the form $P(x|\theta^b_i) = e^{-E(x,\theta^b_i)}/Z(\theta^b_i)$, where $Z(\theta^b_i)$ is the normalization term. However, evaluating $Z(\theta^b_i)$ is not trivial, which risks getting stuck in a local minimum. Following Hinton, we use samples obtained through a Markov Chain Monte Carlo (MCMC) procedure with a specific initialization strategy to deal with this problem. From the definition of $\mathcal{L}_{\rm cd}$, the $\mathcal{L}_{\rm cd^+}$ term can then be expressed for the SGD algorithm in a local silo as:

$$\mathcal{L}_{\rm cd^+} = -\sum_{x} P_0(x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i} + \sum_{x} Q_{\theta^b_i}(x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i}$$

where $Q_{\theta^b_i}(x|\theta^s_i)$ is the measured probability distribution on the samples obtained by initializing the chain at $P_0(x|\theta^s_i)$ and running the Markov chain forward for a fixed number of steps.
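To make the contrastive divergence estimate above concrete, the sketch below applies it to a deliberately simple, hypothetical one-parameter energy $E(x;\theta) = \tfrac{1}{2}(x-\theta)^2$, for which $\partial E/\partial\theta = \theta - x$. The samples for $Q$ come from a few Metropolis-Hastings steps started at the data; the choice of sampler is our assumption, as the post does not name one. The quantity computed is the CD approximation of the log-likelihood gradient, so the parameter moves along it.

```python
import numpy as np

def dE_dtheta(x, theta):
    """Gradient of the toy energy E(x; theta) = 0.5 * (x - theta)**2 w.r.t. theta."""
    return theta - x

def mh_chain(x0, theta, n_steps, rng, step=0.5):
    """Metropolis-Hastings chain targeting P(x|theta) proportional to exp(-E(x; theta)),
    initialised at the data samples x0 (the CD initialisation strategy)."""
    x = x0.copy()
    for _ in range(n_steps):
        prop = x + step * rng.standard_normal(x.shape)
        log_ratio = 0.5 * (x - theta) ** 2 - 0.5 * (prop - theta) ** 2  # E(x) - E(prop)
        accept = np.log(rng.random(x.shape)) < log_ratio
        x = np.where(accept, prop, x)
    return x

def cd_estimate(data, theta, n_steps, rng):
    """-<dE/dtheta>_{P0} + <dE/dtheta>_{Q}, with Q from chains started at the data."""
    samples = mh_chain(data, theta, n_steps, rng)
    return -dE_dtheta(data, theta).mean() + dE_dtheta(samples, theta).mean()

rng = np.random.default_rng(0)
data = rng.normal(2.0, 1.0, size=2000)  # toy "observed" samples
theta = 0.0
for _ in range(300):
    theta += 0.5 * cd_estimate(data, theta, n_steps=1, rng=rng)  # CD-1 ascent step
print(theta, data.mean())
```

With a single chain step (CD-1), $\theta$ settles near the sample mean of the data, which is the maximum-likelihood solution for this toy energy, without ever evaluating $Z(\theta)$.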

We treat the $\mathcal{L}_{\rm cd^-}$ regularizer as a Bayesian statistical inference task in the same way and compute it analogously, with the roles of $\theta^s$ and $\theta^b$ swapped:

$$\mathcal{L}_{\rm cd^-} = -\sum_{x} P_0(x|\theta^b_i)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i} + \sum_{x} Q_{\theta^s_i}(x|\theta^b_i)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i}$$

The key difference is that the backbone weights $\theta^b_i$ are updated by the accumulation process, whereas the sub-network weights $\theta^s_i$ are not. This leads to different convergence behavior of contrastive divergence in $\mathcal{L}_{\rm cd^+}$ and $\mathcal{L}_{\rm cd^-}$. The negative regularizer term $\mathcal{L}_{\rm cd^-}$ converges to the state $\theta^{s*}_i$ provided $\frac{\partial E}{\partial \theta^s_i}$ is bounded and the condition below is satisfied, with

$$g(x,\theta^s_i) = \frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i} - \sum_{x} P_0\big(x|(\theta^b_i,\theta^s_i)\big)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i}$$

$$(\theta^s_i - \theta^{s*}_i)\cdot\left\{ \sum_{x} P_0(x)\, g(x,\theta^s_i) - \sum_{x',x} P_0(x')\,\mathbf{K}^m_{\theta^s_i}(x',x)\, g(x,\theta^{s*}_i)\right\} \geq \mathbf{k}_1 |\theta^s_i-\theta^{s*}_i|^2$$

for some constant $\mathbf{k}_1 > 0$, where $\mathbf{K}^m_{\theta^s_i}$ is the transition kernel.

Note that the negative regularizer term $\mathcal{L}_{\rm cd^-}$ is only used when training models on local silos; it does not contribute to the accumulation process of federated training.

3. Total Training Loss

Local Regression Loss. We use the mean absolute error (MAE) to compute the loss for predicting the steering angle in each local silo. Note that we only use features from the backbone for predicting steering angles.

$$\mathcal{L}_{\rm lr} = \text{MAE}(\theta^b_i, \hat{\xi}_i)$$

where $\hat{\xi}_i$ is the ground-truth steering angle of the data sample $\xi_i$ collected from silo $i$.

Local Silo Loss. The local silo loss, computed at each silo in each communication round before the accumulation process is applied, is:

$$\mathcal{L}_{\rm final} = \mathcal{L}_{\rm lr} + \mathcal{L}_{\rm cd}$$

In practice, we observe that the contrastive divergence loss $\mathcal{L}_{\rm cd}$, which handles the non-IID problem, and the local regression loss $\mathcal{L}_{\rm lr}$, which predicts the steering angle, are equally important and indispensable.
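Putting the pieces together, the sketch below computes the per-silo losses for one minibatch: $\mathcal{L}_{\rm lr}$ as the MAE between the backbone's steering predictions and the ground truth, the two contrastive terms, $\mathcal{L}_{\rm final} = \mathcal{L}_{\rm lr} + \mathcal{L}_{\rm cd}$, and the backbone objective $\mathcal{L}_{\rm b} = \mathcal{L}_{\rm lr} + \mathcal{L}_{\rm cd^+}$ used in the update rule that follows. Applying $\mathcal{H}$ to softmax-normalised branch features is our assumption, and `silo_losses` is a hypothetical helper.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_h(y_hat, y):
    """KL(softmax(y_hat) || softmax(y)), averaged over the batch (assumed form of H)."""
    lp, lq = log_softmax(y_hat), log_softmax(y)
    return float((np.exp(lp) * (lp - lq)).sum(axis=-1).mean())

def silo_losses(steer_pred, steer_gt, feat_b, feat_s, beta):
    """Losses at one silo for one minibatch; steering is predicted from the backbone only."""
    l_lr = float(np.mean(np.abs(steer_pred - steer_gt)))  # MAE
    l_pos, l_neg = kl_h(feat_b, feat_s), kl_h(feat_s, feat_b)
    l_cd = beta * l_pos + (1.0 - beta) * l_neg
    return {
        "final": l_lr + l_cd,      # L_final = L_lr + L_cd
        "backbone": l_lr + l_pos,  # L_b = L_lr + L_cd+ (drives the backbone update)
        "subnet": l_neg,           # L_cd- (drives the sub-network update)
    }

feats = np.ones((2, 8))
losses = silo_losses(np.array([0.1, -0.2]), np.zeros(2), feats, feats, beta=0.5)
print(losses)
```

With identical branch features, both contrastive terms vanish and every loss reduces to the MAE of the steering predictions.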

Combining all losses, at each iteration $k$, the backbone network is updated as:

$$\theta^b_i(k+1) = \begin{cases} \sum_{j \in \mathcal{N}_i^{+} \cup \{i\}} \mathbf{A}_{i,j}\, \theta^b_j(k), & \text{if } k \equiv 0 \pmod{u+1},\\ \theta^b_i(k) - \alpha_k \frac{1}{m}\sum_{h=1}^{m} \nabla \mathcal{L}_{\rm b}\big(\theta^b_i(k), \xi_i^h(k)\big), & \text{otherwise.} \end{cases}$$

where $\mathcal{L}_{\rm b} = \mathcal{L}_{\rm lr} + \mathcal{L}_{\rm cd^+}$ and $u$ is the number of local updates.

In parallel, the sub-network is updated at each iteration $k$ as:

$$\theta^s_i(k+1) = \theta^s_i(k) - \alpha_k \frac{1}{m}\sum_{h=1}^{m} \nabla \mathcal{L}_{\rm cd^-}\big(\theta^s_i(k), \xi_i^h(k)\big)$$
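The two update rules can be simulated jointly for all $N$ silos with their weights stacked as rows. In the sketch below, the mixing matrix `A` (with $A_{i,j} = 0$ unless $j$ is a neighbor of $i$ or $j = i$) and the precomputed, minibatch-averaged gradients are illustrative inputs, and the helper names are hypothetical:

```python
import numpy as np

def backbone_step(k, u, thetas_b, A, grads_b, lr):
    """Backbone update for all silos at iteration k.

    thetas_b: (N, d) backbone weights; A: (N, N) mixing weights A[i, j];
    grads_b: (N, d) minibatch-averaged gradients of L_b = L_lr + L_cd+.
    """
    if k % (u + 1) == 0:
        return A @ thetas_b         # accumulation: theta_i <- sum_j A[i, j] * theta_j
    return thetas_b - lr * grads_b  # otherwise: local SGD step on L_b

def subnet_step(thetas_s, grads_s, lr):
    """Sub-network update: SGD on L_cd- only; these weights are never accumulated."""
    return thetas_s - lr * grads_s

# Toy example: 3 silos on a ring, 1-D weights, u = 2 local updates between accumulations.
A = np.array([[0.5, 0.5, 0.0],
              [0.0, 0.5, 0.5],
              [0.5, 0.0, 0.5]])
thetas = np.array([[0.0], [3.0], [6.0]])
after_accumulation = backbone_step(0, 2, thetas, A, np.zeros_like(thetas), lr=0.1)
after_local = backbone_step(1, 2, thetas, A, np.ones_like(thetas), lr=0.1)
print(after_accumulation.ravel(), after_local.ravel())
```

With $u = 2$, the silos accumulate at $k = 0, 3, 6, \dots$ and take local steps in between, matching the $k \equiv 0 \pmod{u+1}$ condition.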


In the next post, we will evaluate the effectiveness of the Contrastive Divergence Loss in dealing with the non-IID problem in federated autonomous driving.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 1)

Federated learning has been widely applied in autonomous driving since it enables training a learning model among vehicles without sharing users' data. However, data from autonomous vehicles usually suffer from the non-independent-and-identically-distributed (non-IID) problem, which may cause negative effects on the convergence of the learning process. In this paper, we propose a new contrastive divergence loss to address the non-IID problem in autonomous driving by reducing the impact of divergence factors from transmitted models during the local learning process of each silo. We also analyze the effects of contrastive divergence in various autonomous driving scenarios, under multiple network infrastructures, and with different centralized/distributed learning schemes. Our intensive experiments on three datasets demonstrate that our proposed contrastive divergence loss significantly improves the performance over current state-of-the-art approaches.

1. Introduction

Autonomous driving represents a burgeoning field wherein vehicles navigate without human intervention, employing a blend of vision, learning, and control algorithms to perceive and react to environmental changes. While traditional approaches heavily rely on supervised learning and data collection for model training, recent efforts, such as those by Bergamini et al., Chen et al., and Wang et al., have proposed various solutions to different challenges in autonomous driving. However, data collection poses privacy concerns as user data is shared with third parties, prompting a shift towards federated learning (FL). FL allows multiple parties to collaboratively train models without sharing raw data, thereby preserving user privacy while enabling autonomous vehicles to collectively learn predictive models from diverse datasets.

Typically, there are two primary federated learning scenarios: Server-based Federated Learning (SFL) and Decentralized Federated Learning (DFL). In SFL, a central node coordinates the training process and aggregates contributions from all clients, enhancing data privacy by transmitting only model weights. However, the reliance on central nodes in SFL can potentially bottleneck the system due to data transmission constraints. In contrast, DFL operates without a central server, employing a fully distributed network architecture. In autonomous driving, several works have explored both DFL and SFL to address different problems such as collision avoidance, trajectory prediction, and steering prediction.

Figure 1. Sample viewpoints over three vehicles and their steering angle distributions in Carla dataset. The visualization shows that the three vehicles have differences in visual input as well as steering angle distribution.

In practical application, both Server-based Federated Learning (SFL) and Decentralized Federated Learning (DFL) approaches exhibit their own advantages and limitations, yet both are susceptible to the non-IID problem inherent in federated learning. The non-IID problem, as described by Sattler et al. and Wang et al., arises when data partitioning across silos exhibits significant distribution shifts. Particularly in autonomous driving, this issue poses significant challenges during the accumulation process across vehicle silos. Each vehicle's unique driving patterns, weather conditions, and road types contribute to differences in data distribution, exacerbating the non-IID problem. For instance, data collected from vehicles on highways may differ substantially from data collected in urban settings. This discrepancy can lead to difficulties in constructing robust and accurate learning models, where reliance on data from a specific context may result in poor performance in others.

Previous studies have tackled the non-IID problem through various approaches such as optimizing the accumulation step, incorporating normalization into the global model, fine-tuning global model weights via distillation, or aligning the distribution between the global model and local ones. These methods typically refine global model weights to adapt to divergence factors arising from diverse distributions across local datasets. However, this optimization process can be challenging due to factors like determining the optimal accumulation step size for each local silo, managing network topology disconnections and bottlenecks, and ensuring consistency between local and global objectives.

In the paper, we introduce a novel Contrastive Divergence Loss (CDL) to tackle the non-IID problem in autonomous driving. Unlike other methods that wait for local model learning before addressing the non-IID issue during global aggregation, our CDL loss directly mitigates the impact of divergence factors throughout each local silo's learning process. This approach simplifies the adaptation of distribution between neighboring silos compared to handling the non-IID problem in the global model. Consequently, each local model becomes more resilient to changes in data distribution, allowing the use of typical accumulation methods like FedAvg for global weights without concerns about non-IID effects on convergence. The intensive experiments on three autonomous driving datasets verify our observation and show significant improvements over state-of-the-art methods.

2. Related Works

Autonomous Driving: Autonomous driving has emerged as a prominent research area in recent years, with a focus on leveraging deep learning for various tasks such as object detection, trajectory prediction, and autonomous control. For instance, Xin et al. proposed a recursive backstepping steering controller, while Xiong et al. analyzed nonlinear dynamics behavior using proportional control laws. Yi et al. presented an algorithm for self-reconfigurable robots during waypoint navigation, and Yin et al. combined model predictive control with covariance steering theory for robust autonomous driving systems.

Federated Learning for Autonomous Driving: Federated learning offers a privacy-aware solution for collaborative machine learning without sharing local data. It has gained traction in autonomous driving and intelligent transport systems, addressing communication, computation, and storage efficiency. Various works have explored federated learning for autonomous driving tasks such as steering angle prediction and turning signal prediction. However, challenges remain, including non-IID data distribution among participants, which recent works primarily address through accumulation processes rather than local silo optimization.

Contrastive Divergence: Contrastive models have gained attention in federated learning for handling heterogeneity in local data distribution. While contrastive learning has been applied to various datasets, its potential for mitigating non-IID data effects in federated autonomous driving remains underexplored. Most existing works focus on optimizing frameworks in server-based or P2P federated learning, or adjusting the accumulation process, with limited consideration of contrastive loss function behavior in federated autonomous driving scenarios.

Figure 2. A federated autonomous driving system.

3. Preliminary

We summarize the notation used in the paper in Table 1.

Table 1. Notations.

We consider each autonomous vehicle as a data silo. Our goal is to collaboratively train a global driving policy $\theta$ from $N$ silos by aggregating the local learnable weights $\theta_i$ of all silos. Each silo computes the current gradient of its local loss function and then updates the model parameters using an optimizer. Mathematically, in the local update stage, the model weights at silo $i$ in iteration $k$ are computed as:

$$\theta_i(k+1) = \theta_i(k) - \alpha_k \frac{1}{m}\sum_{h=1}^{m} \nabla \mathcal{L}_{\rm lr}\big(\theta_i(k), \xi_i^h(k)\big)$$

where $\mathcal{L}_{\rm lr}$ is the local regression loss for autonomous steering. To update the global model, each silo interacts with its associated silos through a predefined topology:

$$\theta(k+1) = \sum_{i=0}^{N-1} \vartheta_i \theta_i(k)$$
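The aggregation step is a weighted average of the silos' weights. The post does not specify the coefficients $\vartheta_i$; the sketch below assumes FedAvg-style weights proportional to each silo's number of local samples, which is one common choice. `aggregate` is a hypothetical helper.

```python
import numpy as np

def aggregate(local_weights, silo_sizes):
    """theta(k+1) = sum_i vartheta_i * theta_i(k).

    Assumption: vartheta_i = n_i / sum_j n_j (FedAvg-style weighting by local sample count).
    """
    thetas = np.stack(local_weights)             # (N, d): one row per silo
    sizes = np.asarray(silo_sizes, dtype=float)
    vartheta = sizes / sizes.sum()               # non-negative, sums to 1
    return vartheta @ thetas                     # weighted average of the rows

global_theta = aggregate([np.array([1.0, 0.0]), np.array([0.0, 1.0])], silo_sizes=[3, 1])
print(global_theta)
```

Any other non-negative $\vartheta_i$ that sum to one, such as uniform weights, fits the same equation; only the `vartheta` line would change.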

In practice, the local model in each silo is a deep network that takes RGB images as input and predicts the associated steering angles.


In the next post, we will introduce our proposal with contrastive divergence loss.

Fine-Grained Visual Classification using Self-Assessment Classifier (Part 2)

In the previous part, we discussed our proposal for fine-grained image classification. In this part, we verify its effectiveness and efficiency.

1. Experimental Setup

Dataset. We evaluate our method on three popular fine-grained datasets: CUB-200-2011, Stanford Dogs and FGVC Aircraft (See Table 1).

| Dataset | Target | # Categories | # Train | # Test |
|---|---|---|---|---|
| Stanford Dogs | Dog | 120 | 12,000 | 8,580 |

Table 1: Fine-grained classification datasets in our experiments.

Implementation. All experiments are conducted on an NVIDIA Titan V GPU with 12GB RAM. The model is trained using Stochastic Gradient Descent with a momentum of 0.9. The maximum number of epochs is set to 80, the weight decay is 0.00001, and the mini-batch size is 12. The initial learning rate is set to 0.001, with an exponential decay of 0.9 every two epochs. Based on validation results, the number of top-k ambiguity classes is set to 10, while the parameters $d_{\phi}$ and $\alpha$ are set to $0.1$ and $0.5$, respectively.
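The learning-rate schedule can be written as a one-line function. A minimal sketch, assuming "exponential decay of 0.9 after every two epochs" means the rate is multiplied by 0.9 once per two-epoch block (a step-wise decay) and that epochs are counted from 0:

```python
def learning_rate(epoch, base_lr=1e-3, decay=0.9, every=2):
    """Step-wise exponential decay: multiply base_lr by `decay` once per `every` epochs."""
    return base_lr * decay ** (epoch // every)

schedule = [learning_rate(e) for e in range(80)]  # epochs 0..79
print(schedule[0], schedule[2], schedule[-1])
```

The post does not name a framework. If the training code uses PyTorch, this setup would presumably correspond to `torch.optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=1e-5)` combined with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)`, stepped once per epoch.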

Baseline. To validate the effectiveness and generalization of our method, we integrate it into 7 different deep networks: two popular Deep CNN backbones, Inception-V3 and ResNet-50, and five fine-grained classification methods, WS, DT, WS_DAN, MMAL, and the recent transformer work ViT. Note that we only add our Self Assessment Classifier to these works; all other setups and training hyper-parameters are kept unchanged from the original code.

2. Experimental Results

Table 2 summarises the contribution of our Self Assessment Classifier (SAC) to the fine-grained classification results of different methods on three datasets: CUB-200-2011, Stanford Dogs, and FGVC Aircraft. Integrating SAC into different classifiers consistently improves the fine-grained classification results, with average improvements of $+1.3$, $+1.2$, and $+1.2$ accuracy points on CUB-200-2011, Stanford Dogs, and FGVC Aircraft, respectively.

| Methods | CUB-200-2011 | Stanford Dogs | FGVC Aircraft |
|---|---|---|---|
| Parts Models | 90.4 | 93.9 | - |
| ViT + DCAL | 91.4 | - | 91.5 |
| Inception-V3 + SAC | 85.3 (+1.6) | 86.8 (+1.7) | 89.2 (+1.8) |
| ResNet-50 + SAC | 88.3 (+1.9) | 87.4 (+1.3) | 92.1 (+1.8) |
| WS + SAC | 89.9 (+1.1) | 92.5 (+1.1) | 93.2 (+0.9) |
| DT + SAC | 90.1 (+0.9) | 88.8 (+0.8) | 91.9 (+1.2) |
| MMAL + SAC | 90.8 (+1.2) | 91.6 (+1.0) | 95.5 (+0.8) |
| WS_DAN + SAC | 91.1 (+1.7) | 93.1 (+0.9) | 93.9 (+0.9) |
| ViT + SAC | 91.8 (+0.8) | 94.5 (+1.3) | 93.1 (+1.0) |
| Avg. Improvement | +1.3 | +1.2 | +1.2 |

Table 2: Contribution (% Acc) of our Self Assessment Classifier (SAC) to fine-grained classification results.
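The average-improvement row is the unweighted mean of the seven per-method gains in each column, which can be checked directly from the table values:

```python
# Per-method gains (accuracy points) copied from Table 2, in row order:
# Inception-V3, ResNet-50, WS, DT, MMAL, WS_DAN, ViT.
gains = {
    "CUB-200-2011": [1.6, 1.9, 1.1, 0.9, 1.2, 1.7, 0.8],
    "Stanford Dogs": [1.7, 1.3, 1.1, 0.8, 1.0, 0.9, 1.3],
    "FGVC Aircraft": [1.8, 1.8, 0.9, 1.2, 0.8, 0.9, 1.0],
}
# Unweighted mean over the seven methods, rounded to one decimal like the table.
avg = {name: round(sum(values) / len(values), 1) for name, values in gains.items()}
print(avg)
```

The unrounded means are about 1.31, 1.16, and 1.20, so the reported +1.3, +1.2, and +1.2 are consistent with the per-method gains.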

3. Qualitative Results

Attention Maps. Figure 1 visualizes the attention maps between the image feature maps and each ambiguity class. The visualization indicates that, with our Self Assessment Classifier, each fine-grained class focuses on different informative regions.

Figure 1. Visualization of the attention maps between image feature maps and different ambiguity classes from our method. A red class label denotes that the prediction matches the ground truth.

Prediction Results. Figure 2 illustrates the classification results and corresponding localization areas of different methods. In all samples, our SAC focuses on different areas depending on the hard-to-distinguish classes. It therefore attends to more meaningful areas and ignores unnecessary ones, which lets SAC make good predictions even in challenging cases.

Figure 2. Qualitative comparison of different classification methods. (a) Input image and its corresponding ground-truth label, (b) ResNet-50, (c) WS_DAN, (d) MMAL, and (e) Our SAC. Boxes are localization areas. Red color indicates wrong classification result. Blue color indicates correct predicted label.


We introduce a Self Assessment Classifier (SAC) which effectively learns the discriminative features in the image and resolves the ambiguity from the top-k prediction classes. Our method generates the attention map and uses this map to dynamically erase unnecessary regions during the training. The intensive experiments on CUB-200-2011, Stanford Dogs, and FGVC Aircraft datasets show that our proposed method can be easily integrated into different fine-grained classifiers and clearly improve their accuracy.

Fine-Grained Visual Classification using Self-Assessment Classifier (Part 1)

Extracting discriminative features plays a crucial role in the fine-grained visual classification task. Most existing methods focus on developing attention or augmentation mechanisms to achieve this goal. However, addressing the ambiguity in the top-k prediction classes has not been fully investigated. In this paper, we introduce a Self Assessment Classifier, which simultaneously leverages the representation of the image and the top-k prediction classes to reassess the classification results. Our method is inspired by self-supervised learning with coarse-grained and fine-grained classifiers to increase the discrimination of features in the backbone and produce attention maps of informative areas on the image. In practice, our method works as an auxiliary branch and can be easily integrated into different architectures. We show that by effectively addressing the ambiguity in the top-k prediction classes, our method achieves new state-of-the-art results on the CUB-200-2011, Stanford Dogs, and FGVC Aircraft datasets. Furthermore, our method also consistently improves the accuracy of different existing fine-grained classifiers with a unified setup.

1. Introduction

The task of fine-grained visual classification involves categorizing images into subcategories of the same meta-category (e.g., species of birds, types of aircraft, or varieties of flowers). Compared to standard image classification, fine-grained classification is more challenging for three main reasons: (i) large intra-class variation, where objects within the same category exhibit diverse poses and viewpoints; (ii) subtle inter-class differences, where objects from different categories may look very similar except for minor details, e.g., the color pattern of a bird's head often determines its fine-grained class; and (iii) limited training data, as annotating fine-grained categories typically requires specialized expertise and considerable effort. Consequently, achieving accurate classification results solely with state-of-the-art CNN models such as VGG is nontrivial.

Recent research demonstrates that a crucial strategy for fine-grained classification involves identifying informative regions across various parts of objects and extracting distinguishing features. A common approach to achieving this is by learning the object's parts through human annotations. However, annotating fine-grained regions is labor-intensive, rendering this method impractical. Some advancements have explored unsupervised or weakly-supervised learning techniques to identify informative object parts or region of interest bounding boxes. While these methods offer potential solutions to circumvent manual labeling of fine-grained regions, they come with limitations such as reduced accuracy, high computational costs during training or inference, and challenges in accurately detecting distinct bounding boxes.

In this paper, we introduce the Self Assessment Classifier (SAC) method to tackle the inherent ambiguity in fine-grained classification tasks. Our approach re-evaluates the top-k predictions and filters out uninformative regions of the input image. This mitigates inter-class ambiguity and enables the backbone network to learn more discriminative features. During training, our method generates attention maps that highlight informative regions within the input image. By integrating this method into a backbone network, we aim to reduce misclassifications among the top-k ambiguous classes. Here, "ambiguity classes" refer to instances where uncertainty in prediction can lead to incorrect classifications. Our contributions can be summarized as follows:

  • We propose a novel self-class assessment method that simultaneously learns discriminative features and addresses ambiguity issues in fine-grained visual classification tasks.
  • We demonstrate the versatility of our method by showcasing its seamless integration into various fine-grained classifiers, resulting in improved state-of-the-art performance.

Figure 1. Comparison between generic classification and fine-grained classification.

2. Method Overview

We propose two main steps in our method: Top-k Coarse-grained Class Search (TCCS) and Self Assessment Classifier (SAC). TCCS works as a coarse-grained classifier to extract visual features from the backbone. The Self Assessment Classifier works as a fine-grained classifier to reassess the ambiguity classes and eliminate the non-informative regions. Our SAC has four modules: the Top-k Class Embedding module encodes the information of the ambiguity classes; the Joint Embedding module jointly learns the coarse-grained features and top-k ambiguity classes; the Self Assessment module differentiates between ambiguity classes; and finally, the Dropping module is a data augmentation method designed to erase unnecessary inter-class similar regions from the input image. Figure 2 shows an overview of our approach.

Figure 2. Method Overview.

3. Top-k Coarse-grained Class Search

The TCCS takes an image as input. Each input image is passed through a Deep CNN to extract the feature map $\textbf{\textit{F}} \in \mathbb{R}^{d_f \times m \times n}$ and the visual feature $\textbf{\textit{V}} \in \mathbb{R}^{d_v}$, where $m$, $n$, and $d_f$ represent the feature map height, width, and number of channels, respectively, and $d_v$ denotes the dimension of the visual feature $\textbf{\textit{V}}$. In practice, $\textbf{\textit{V}}$ is usually obtained by applying some fully connected layers after the convolutional feature map $\textbf{\textit{F}}$.

The visual feature $\textbf{\textit{V}}$ is used by the $1^{st}$ classifier, i.e., the original classifier of the backbone, to obtain the top-k prediction results. Assume that the fine-grained dataset has $N$ classes. The top-k prediction results $C_k = \{C_1, ..., C_k\}$ form a subset of all prediction classes $C_N$, where $k$ is the number of candidates with the $k$ highest confidence scores.
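Selecting $C_k$ amounts to taking the indices of the $k$ largest scores produced by the first classifier. A minimal NumPy sketch; the function name and the toy scores are hypothetical:

```python
import numpy as np

def top_k_classes(scores, k):
    """Indices of the k highest-scoring classes, best first (the set C_k)."""
    idx = np.argpartition(scores, -k)[-k:]     # the k largest, in arbitrary order
    return idx[np.argsort(scores[idx])[::-1]]  # sort only those k by score, descending

scores = np.array([0.1, 2.0, -1.0, 0.7, 1.5])  # toy scores for N = 5 classes
C_k = top_k_classes(scores, k=3)
print(C_k)
```

`np.argpartition` avoids fully sorting all $N$ scores; only the $k$ selected candidates are sorted.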

4. Self Assessment Classifier

Our Self Assessment Classifier takes the image feature $\textbf{\textit{F}}$ and the top-k prediction $C_k$ from TCCS as input to reassess the fine-grained classification results.

Top-k Class Embedding

The output $C_k$ of the TCCS module is passed through the top-k class embedding module to produce the label embedding set $\textbf{E}_k = \{E_1, ..., E_i, ..., E_k\}$, $i \in \{1, 2, ..., k\}$, $E_i \in \mathbb{R}^{d_e}$. This module contains a word embedding layer (Pennington et al., 2014) for encoding each word in the class labels and a GRU layer (Cho et al., 2014) for learning the temporal information in class label names; $d_e$ represents the dimension of each class label embedding. Note that the embedding module is trained end-to-end with the whole model, so the class label representations are learned from scratch without any pre-extracted or pre-trained features or transfer learning.

Given an input class label, we trim it to a maximum of $4$ words; class labels shorter than $4$ words are zero-padded. Each word is then represented by a $300$-D word embedding. This step results in a sequence of word embeddings of size $4 \times 300$, denoted $\hat{E}_i$ for the $i$-th class label in the class label set $C_k$. To capture the dependencies within the class label name, $\hat{E}_i$ is passed through a Gated Recurrent Unit (GRU), which produces a $1024$-D vector representation $E_i$ for each input class. Note that although we use the language modality (i.e., the class label name), it is not extra information, since the class label name and the class label identity (used for calculating the loss) represent the same object category.
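The label pre-processing described above (trim to 4 words, zero-pad, embed each word in 300-D) can be sketched in pure Python as follows. The whitespace tokenizer, the toy vocabulary, and the random embedding table are hypothetical stand-ins; in the model the table is learned end-to-end, and the GRU that turns the $4 \times 300$ matrix into the 1024-D vector $E_i$ is not shown.

```python
import random

MAX_WORDS = 4    # labels are trimmed/zero-padded to 4 words (from the text)
EMB_DIM = 300    # each word becomes a 300-D embedding (from the text)
PAD_ID = 0       # hypothetical: id 0 is reserved for zero padding

def build_vocab(labels):
    """Assign an integer id >= 1 to every lower-cased word in the labels."""
    vocab = {}
    for label in labels:
        for word in label.lower().split():
            vocab.setdefault(word, len(vocab) + 1)
    return vocab

def encode_label(label, vocab):
    """Trim a class label to MAX_WORDS word ids and zero-pad shorter ones.

    Words missing from the vocabulary are mapped to PAD_ID (an assumption).
    """
    ids = [vocab.get(w, PAD_ID) for w in label.lower().split()][:MAX_WORDS]
    return ids + [PAD_ID] * (MAX_WORDS - len(ids))

def embed(ids, table):
    """Return a MAX_WORDS x EMB_DIM matrix; padding positions are all zeros."""
    return [table[i] if i != PAD_ID else [0.0] * EMB_DIM for i in ids]

# Toy usage with a random (hypothetical) embedding table.
labels = ["Blazer", "Denim Jacket", "Long Sleeve Button Down Shirt"]
vocab = build_vocab(labels)
rng = random.Random(0)
table = {i: [rng.gauss(0.0, 0.1) for _ in range(EMB_DIM)] for i in vocab.values()}
print(encode_label("Blazer", vocab))                         # [1, 0, 0, 0]
print(encode_label("Long Sleeve Button Down Shirt", vocab))  # [4, 5, 6, 7]
E_hat = embed(encode_label("Denim Jacket", vocab), table)
print(len(E_hat), len(E_hat[0]))                             # 4 300
```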

Joint Embedding

This module takes the feature map $\textbf{\textit{F}}$ and the top-k class embedding $\textbf{E}_k$ as input to produce the joint representation $\textbf{\textit{J}} \in \mathbb{R}^{d_j}$ and the attention map. We first flatten $\textbf{\textit{F}}$ into a $d_f \times f$ matrix, where $f = m \times n$ is the number of spatial positions, and arrange $\textbf{E}_k$ as a $d_e \times k$ matrix. The joint representation $\textbf{\textit{J}}$ is calculated from the two modalities $\textbf{\textit{F}}$ and $\textbf{E}_k$ as follows:

$$\textbf{\textit{J}}^T = \left(\mathcal{T} \times_1 \text{vec}(\textbf{\textit{F}}) \right) \times_2 \text{vec}(\textbf{E}_k)$$

where $\mathcal{T} \in \mathbb{R}^{d_{\textbf{\textit{F}}} \times d_{\textbf{E}_k} \times d_j}$ is a learnable tensor, $d_{\textbf{\textit{F}}} = d_f \times f$, $d_{\textbf{E}_k} = d_e \times k$, $\text{vec}(\cdot)$ is the vectorization operator, and $\times_i$ denotes the $i$-mode tensor product.
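To make the $\times_1$ / $\times_2$ notation concrete, here is a pure-Python sketch of the fully parameterized interaction at toy sizes. It assumes row-major $\text{vec}(\cdot)$ and nested lists for tensors; the names and sizes are hypothetical. Note that $\mathcal{T}$ alone has $d_{\textbf{\textit{F}}} \cdot d_{\textbf{E}_k} \cdot d_j$ entries.

```python
def vec(matrix):
    """Flatten a list-of-rows matrix into one vector (row-major order, assumed)."""
    return [v for row in matrix for v in row]

def full_bilinear(T, x, y):
    """Compute J^T = (T x_1 x) x_2 y for a 3-way tensor T of shape (len(x), len(y), d_j)."""
    d_j = len(T[0][0])
    # T x_1 x: contract the first mode with x, leaving a len(y) x d_j matrix.
    TX = [[sum(T[a][b][l] * x[a] for a in range(len(x))) for l in range(d_j)]
          for b in range(len(y))]
    # (T x_1 x) x_2 y: contract the remaining mode with y, leaving a d_j vector.
    return [sum(TX[b][l] * y[b] for b in range(len(y))) for l in range(d_j)]

# Toy sizes: d_f = 2 channels, f = 2 positions, d_e = 2, k = 1 class, d_j = 1.
F = [[1.0, 0.0], [0.0, 2.0]]            # d_f x f
E_k = [[3.0], [4.0]]                    # d_e x k
x, y = vec(F), vec(E_k)                 # lengths d_F = 4 and d_E = 2
T = [[[1.0] for _ in y] for _ in x]     # all-ones tensor, shape (d_F, d_E, d_j)
print(full_bilinear(T, x, y))           # [21.0]
```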

In practice, the tensor $\mathcal{T}$ above is too large to learn. We therefore apply a decomposition that reduces the size of $\mathcal{T}$ while retaining its learning effectiveness, based on the idea of the unitary attention mechanism. Specifically, let $\textbf{\textit{J}}_p \in \mathbb{R}^{d_j}$ be the joint representation of the $p^{th}$ couple of channels, where the two channels in a couple come from different inputs. The joint representation $\textbf{\textit{J}}$ is approximated using the joint representations of all couples instead of the fully parameterized interaction in the equation for $\textbf{\textit{J}}^T$ above. Hence, we compute $\textbf{\textit{J}}$ as:

$$\textbf{\textit{J}} = \sum_p \mathcal{M}_p \textbf{\textit{J}}_p$$

Note that in the equation above, we compute a weighted sum over all possible couples. The $p^{th}$ couple is associated with a scalar weight $\mathcal{M}_p$. The set of weights $\mathcal{M}_p$ is called the attention map $\mathcal{M}$, where $\mathcal{M} \in \mathbb{R}^{f \times k}$.

There are $f \times k$ possible couples over the two modalities. The representations of the two channels in a couple are $\textbf{\textit{F}}_i$ and $\left(\textbf{E}_k\right)_j$, where $i \in [1, f]$ and $j \in [1, k]$, respectively. The joint representation $\textbf{\textit{J}}_p$ is then computed as follows:

$$\textbf{\textit{J}}_p^T = \left(\mathcal{T}_u \times_1 \textbf{\textit{F}}_i \right) \times_2 \left(\textbf{E}_k\right)_j$$

where $\mathcal{T}_u \in \mathbb{R}^{d_f \times d_e \times d_j}$ is the learnable tensor between the channels in a couple.

From the equation above, we can compute the attention map $\mathcal{M}$ using a reduced parameterized bilinear interaction over the inputs $\textbf{\textit{F}}$ and $\textbf{E}_k$. The attention map is computed as

$$\mathcal{M} = \text{softmax}\left(\left(\mathcal{T}_\mathcal{M} \times_1 \textbf{\textit{F}} \right) \times_2 \textbf{E}_k \right)$$

where $\mathcal{T}_\mathcal{M} \in \mathbb{R}^{d_f \times d_e}$ is a learnable tensor.
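The factorized joint embedding can be sketched in pure Python as follows. This is a minimal illustration, not the authors' implementation: it assumes the softmax is taken jointly over all $f \times k$ couples, stores $\textbf{\textit{F}}$ and $\textbf{E}_k$ as lists of columns, and uses hypothetical names and toy sizes.

```python
import math

def bilinear(T, x, y):
    """(T x_1 x) x_2 y, i.e. J[l] = sum_a sum_b T[a][b][l] * x[a] * y[b]."""
    d_j = len(T[0][0])
    return [sum(T[a][b][l] * x[a] * y[b]
                for a in range(len(x)) for b in range(len(y)))
            for l in range(d_j)]

def attention_map(F_cols, E_cols, T_M):
    """M[i][j]: softmax over all f*k couples of the score F_i^T T_M E_j (assumed joint softmax)."""
    scores = [[sum(F_i[a] * T_M[a][b] * E_j[b]
                   for a in range(len(F_i)) for b in range(len(E_j)))
               for E_j in E_cols]
              for F_i in F_cols]
    peak = max(max(row) for row in scores)
    exps = [[math.exp(s - peak) for s in row] for row in scores]
    total = sum(sum(row) for row in exps)
    return [[e / total for e in row] for row in exps]

def joint_representation(F_cols, E_cols, T_u, T_M):
    """J^T = sum_ij M_ij * ((T_u x_1 F_i) x_2 E_j), returned with the attention map M."""
    M = attention_map(F_cols, E_cols, T_M)
    J = [0.0] * len(T_u[0][0])
    for i, F_i in enumerate(F_cols):
        for j, E_j in enumerate(E_cols):
            J_p = bilinear(T_u, F_i, E_j)       # joint vector of couple (i, j)
            J = [acc + M[i][j] * v for acc, v in zip(J, J_p)]
    return J, M

# Toy sizes: d_f = d_e = 2, f = k = 2, d_j = 1.
F_cols = [[1.0, 0.0], [0.0, 1.0]]                     # f columns F_i of length d_f
E_cols = [[1.0, 1.0], [2.0, 0.0]]                     # k columns (E_k)_j of length d_e
T_u = [[[1.0] for _ in range(2)] for _ in range(2)]   # shape (d_f, d_e, d_j)
T_M = [[0.0, 0.0], [0.0, 0.0]]                        # zero scores give uniform attention
J, M = joint_representation(F_cols, E_cols, T_u, T_M)
print(J, M)  # [2.0] [[0.25, 0.25], [0.25, 0.25]]
```

Only $\mathcal{T}_u$ and $\mathcal{T}_\mathcal{M}$ carry parameters here; the couple loop replaces the single huge contraction over $\text{vec}(\textbf{\textit{F}})$ and $\text{vec}(\textbf{E}_k)$.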

The joint representation $\textbf{\textit{J}}$ can be rewritten as

$$\textbf{\textit{J}}^T = \sum_{i=1}^{f}\sum_{j=1}^{k} \mathcal{M}_{ij} \left( \left( \mathcal{T}_u \times_1 \textbf{\textit{F}}_i \right) \times_2 \left(\textbf{E}_k\right)_j \right)$$

It is also worth noting from the equation above that, to compute $\textbf{\textit{J}}$, instead of learning the large tensor $\mathcal{T} \in \mathbb{R}^{d_{\textbf{\textit{F}}} \times d_{\textbf{E}_k} \times d_j}$, we now only need to learn two smaller tensors: $\mathcal{T}_u \in \mathbb{R}^{d_f \times d_e \times d_j}$ in the couple-wise equation for $\textbf{\textit{J}}_p$ and $\mathcal{T}_\mathcal{M} \in \mathbb{R}^{d_f \times d_e}$ in the attention map equation.
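A quick parameter count makes the saving concrete. Only $d_e = 1024$ comes from the text; the values of $d_f$, $f$, $k$, and $d_j$ below are hypothetical.

```python
def param_counts(d_f, f, d_e, k, d_j):
    """Learnable entries in the full tensor T versus T_u plus T_M."""
    full = (d_f * f) * (d_e * k) * d_j          # T in R^{d_F x d_E_k x d_j}
    factorized = d_f * d_e * d_j + d_f * d_e     # T_u plus T_M
    return full, factorized

# Only d_e = 1024 comes from the text; d_f, f, k and d_j are hypothetical.
full, factorized = param_counts(d_f=2048, f=7 * 7, d_e=1024, k=5, d_j=512)
print(f"full T: {full:,}  T_u + T_M: {factorized:,}  ratio: {full / factorized:.1f}")
```

Dividing the two counts gives $f k d_j / (d_j + 1)$, so the reduction is roughly a factor of $f \times k$, the number of couples.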

Self Assessment

The joint representation $\textbf{\textit{J}}$ from the Joint Embedding module is used as the input to the Self Assessment step to obtain the $2^{nd}$ top-k predictions $\textbf{C}'_k = \{C'_1, \ldots, C'_k\}$. Intuitively, $\textbf{C}'_k$ is the top-k classification result after self-assessment. This module is a fine-grained classifier that produces the $2^{nd}$ predictions to reassess ambiguous classification results.

The contributions of the coarse-grained and fine-grained classifiers are combined as follows:

$$\text{Pr}(\rho = \rho_i) = \alpha \, \text{Pr}_1(\rho = \rho_i) + (1 - \alpha) \, \text{Pr}_2(\rho = \rho_i)$$

where $\alpha$ is the trade-off hyper-parameter ($0 \leq \alpha \leq 1$), and $\text{Pr}_1(\rho = \rho_i)$ and $\text{Pr}_2(\rho = \rho_i)$ denote the prediction probabilities for class $\rho_i$ from the coarse-grained and fine-grained classifiers, respectively.
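The score fusion above is a convex combination of the two classifiers' probabilities. A minimal sketch, assuming each classifier's output is a dictionary from class name to probability; the class names and values are made up.

```python
def fuse(p1, p2, alpha):
    """Pr(rho_i) = alpha * Pr_1(rho_i) + (1 - alpha) * Pr_2(rho_i) for each class."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    classes = set(p1) | set(p2)
    # A class missing from one classifier's output is given probability 0 there (assumption).
    return {c: alpha * p1.get(c, 0.0) + (1.0 - alpha) * p2.get(c, 0.0) for c in classes}

# Hypothetical outputs of the coarse-grained (p1) and fine-grained (p2) classifiers.
p1 = {"blazer": 0.5, "jacket": 0.4, "coat": 0.1}
p2 = {"blazer": 0.2, "jacket": 0.7, "coat": 0.1}
fused = fuse(p1, p2, alpha=0.5)
print(max(fused, key=fused.get))  # jacket
```

With $\alpha = 1$ the result falls back to the coarse-grained classifier alone, and with $\alpha = 0$ to the self-assessment classifier alone.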


In the next post, we will verify the effectiveness and efficiency of the method.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 5)

In the previous post, we discussed the offline phase in detail. In this post, we explore the online phase in detail.

1. States in Online Phase

The online phase of the CFOR system is illustrated in Figure 1.

Figure 1. Online stage of the CFOR system.

2. Technical Details

The functions used are described as follows:

  • (i) detector(imgQuery): objects in an image are detected automatically by a trained detector. In this function, we reuse the YOLO (version 3.0) software to identify fashion items. In addition, the identified items are refined by the region identification model trained with the “classifyModel” function.
  • (ii) infor_extract(states, obj, onto, classifyModels, multitaskModel): for each query object, all attribute learning models trained with the “multitaskModel” function and the coarse classification models trained with the “classifyModel” function are run. We extract the region ⟶ category ⟶ attributes and the necessary features for each stage of the ontology.
  • (iii) query_expansion(infor, feat): query expansion based on the mean vector is used to rerank the retrieval results.
  • (iv) compute_sim_score(database, infor, feat): for each pair of features, the asymmetric distance is used to measure the dissimilarity between the query and a sample in the database.
  • (v) ranking(score_list, database, top_k, GPU_search = True): ranking is applied to the scores between the query and all samples in the database obtained from the “compute_sim_score” function; a smaller score is better.
  • (vi) retrieval(indexes, score_list, database, GPU_search = True): the retrieval process contains three steps: global feature retrieval, fine-grained retrieval, and query expansion. For global retrieval, the global features of the query object obtained from the “infor_extract” function and the features of the samples in the database are passed to the “ranking” function to get the 1st top-m retrieval results. For fine-grained retrieval, the attribute features of the query object obtained from “infor_extract” and the features of the samples in the 1st top-m retrieval results are passed to “ranking” to get the 2nd top-k retrieval results. For query expansion, the mean vector is computed from the 2nd top-k retrieval results, and each feature of the 2nd top-k retrieval results is passed to “ranking” to get the final top-k retrieval results, i.e., query-expansion-based reranking.
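The two ranking stages of the retrieval function can be sketched as follows. This is a minimal, exact (brute-force) version in pure Python that stands in for the GPU/Faiss search; the function names, the dictionary-based database, and the toy vectors are hypothetical, and the query expansion step is left out.

```python
import heapq
import math

def l2(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def rank(query, candidates, top_n):
    """Ids of the top_n candidates closest to query (smaller score is better)."""
    return heapq.nsmallest(top_n, candidates, key=lambda i: l2(query, candidates[i]))

def coarse_to_fine(q_global, q_attr, db_global, db_attr, top_m, top_k):
    """Stage 1 ranks by global features; stage 2 reranks those top_m by attributes."""
    first = rank(q_global, db_global, top_m)
    return rank(q_attr, {i: db_attr[i] for i in first}, top_k)

# Toy database with hypothetical global and attribute vectors.
db_global = {"a": [0.0, 0.0], "b": [1.0, 0.0], "c": [5.0, 5.0], "d": [0.0, 1.0]}
db_attr = {"a": [1.0, 0.0, 0.0], "b": [0.0, 1.0, 0.0],
           "c": [0.0, 1.0, 0.0], "d": [0.0, 1.0, 1.0]}
print(coarse_to_fine([0.0, 0.0], [0.0, 1.0, 0.0], db_global, db_attr, top_m=3, top_k=2))
# ['b', 'd']
```

Item "c" matches the query attributes exactly but is cut in stage 1, which illustrates the trade-off: the global stage shrinks the search space, and the fine-grained stage cannot recover anything it drops.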

As shown in Figure 1, the online phase of the CFOR system contains three stages that run in real time. They are described as follows.

3. Prediction Stage

This stage uses the object ontology and the classification models obtained in the offline phase to make coarse-to-fine predictions for each query image:

Fine-grained information in terms of regions, categories, and attributes gives customers more options to express a full semantic query. The object is predicted from coarse to fine: the region, category, and attributes are predicted in turn based on the object ontology and a local MDNN. The object retrieval system then uses the extracted semantic information, i.e., the category and attributes, to search in detail.

4. Dissimilarity Measuring Stage

This stage uses the database and the indexing file from the offline phase, together with a dissimilarity measure, to score, rank, rerank, and return retrieval results for each query image. It is based on the dissimilarity between the attribute vectors of the query image and those of the database images:

Combining k-nearest-neighbour search under the L2 distance with asymmetric distance computation, we exploit GPU parallel processing through the Faiss library to compute the distance from the query image to the relevant images in the database. This distance, also called the score of each database image, is then sorted to rank dissimilarity: the smaller the score, the more similar the image is to the query. Based on the number of retrieval images required or on thresholds, we choose an appropriate cutoff on the score and on the number of retrieved images. This measurement is used to compute distances for both deep feature vectors and attribute vectors.
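Asymmetric distance computation can be illustrated with the product-quantization scheme that Faiss-style compressed indexes use: database vectors are stored as centroid indices per sub-space, while the query stays unquantized, so each score reduces to a sum of table look-ups. This is a minimal sketch under that assumption; the codebooks, vectors, and the cutoff value are hypothetical, and squared L2 is used since it gives the same ranking as L2.

```python
def sq_dist(a, b):
    """Squared L2 distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def encode(vector, codebooks):
    """Product-quantize a vector: one nearest-centroid index per sub-space."""
    sub = len(vector) // len(codebooks)
    codes = []
    for m, book in enumerate(codebooks):
        part = vector[m * sub:(m + 1) * sub]
        codes.append(min(range(len(book)), key=lambda c: sq_dist(part, book[c])))
    return codes

def adc_scores(query, codes_db, codebooks):
    """Asymmetric distance: the raw query against quantized database codes."""
    sub = len(query) // len(codebooks)
    # One lookup table per sub-space, computed once per query: squared distance
    # from the query sub-vector to every centroid of that sub-space.
    tables = [[sq_dist(query[m * sub:(m + 1) * sub], centroid) for centroid in book]
              for m, book in enumerate(codebooks)]
    # Each database score is then just a sum of table look-ups.
    return {item: sum(tables[m][idx] for m, idx in enumerate(codes))
            for item, codes in codes_db.items()}

# Hypothetical codebooks: 4-D vectors, 2 sub-spaces, 2 centroids each
# (in practice they are learned with k-means on the database features).
codebooks = [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [2.0, 2.0]]]
database = {"x": [0.1, 0.0, 1.9, 2.1], "y": [0.9, 1.1, 0.0, 0.2], "z": [1.0, 0.9, 2.0, 1.8]}
codes_db = {item: encode(v, codebooks) for item, v in database.items()}
scores = adc_scores([0.0, 0.1, 2.0, 2.0], codes_db, codebooks)
ranked = sorted(scores, key=scores.get)                    # smaller score = more similar
kept = [item for item in ranked if scores[item] <= 2.0]    # hypothetical cutoff
print(ranked, kept)  # ['x', 'z', 'y'] ['x', 'z']
```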

5. Query Expansion Stage

Query expansion is a technique that gathers additional relevant information from the input to increase retrieval performance. Depending on the query expansion algorithm and the data, this information can be relevant images, additional features, descriptions, etc. In this stage, we take advantage of the previous retrieval results and expand the query with the mean vector, reranking the retrieval results to improve retrieval performance.

Among the many query expansion methods, we choose query expansion based on the mean vector: the mean vector computed from the features of the retrieval results and the features of the input helps reduce the bias between the different features considered. Thus, the CFOR system can eliminate unrelated features, i.e., retrieved features with a large gap from the mean vector, which helps reduce outliers and raise the precision score.

Query expansion based on the mean vector is very fast to compute, and it can take advantage of the Faiss similarity search as well. Query expansion can remove outliers thanks to the statistical nature of the mean vector.
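One plausible reading of mean-vector query expansion is sketched below under stated assumptions: the mean of the query feature and the retrieved features becomes the expanded query, the retrieved items are reranked by their L2 distance to it, and an optional, hypothetical `max_gap` drops items far from the mean. It is not necessarily the exact CFOR procedure.

```python
import math

def mean_vector(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def l2(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def mean_query_expansion(query, retrieved, feats, max_gap=None):
    """Rerank retrieved ids by distance to the mean of the query and their features.

    Items farther than max_gap from the mean vector are treated as outliers
    and dropped; max_gap is a hypothetical parameter.
    """
    center = mean_vector([query] + [feats[i] for i in retrieved])
    reranked = sorted(retrieved, key=lambda i: l2(center, feats[i]))
    if max_gap is not None:
        reranked = [i for i in reranked if l2(center, feats[i]) <= max_gap]
    return reranked

feats = {"a": [0.9, 0.1], "b": [1.0, 0.2], "c": [1.1, 0.0], "d": [0.0, 1.5]}
print(mean_query_expansion([1.0, 0.0], ["a", "b", "c", "d"], feats))               # ['b', 'a', 'c', 'd']
print(mean_query_expansion([1.0, 0.0], ["a", "b", "c", "d"], feats, max_gap=1.0))  # ['b', 'a', 'c']
```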


In the next post, we will discuss the fashion ontology used to test the CFOR system in the fashion domain.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 4)

In the previous post, we gave an overview of the offline and online phases. In this post, we explore the offline phase in detail.

1. States in Offline Phase

This phase consists of three substages:

  • Object Ontology Establishment Stage. This stage defines the fashion ontology, which controls both the training flow and the online retrieval flow and serves as a bridge between high-level concepts (objects and categories), midlevel concepts (attributes), and raw data.
  • Learning Stage. This stage exploits deep networks with transfer learning to handle specific tasks, including object part learning, category learning, and attribute learning.
  • Storing and Indexing Stage. This stage defines how data are stored and builds the index list to reduce retrieval (search) time.

In this section, we describe the parts of the offline phase that are inherited from previous state-of-the-art methods: object part extraction, transfer learning and its role in the retrieval system, and data storing. These modules generalize to any object. Other issues, including the ontology, attribute learning, network architecture, and indexing strategy, are detailed in the following sections.

2. Loss Function

The classification models inherit the state-of-the-art ResNet architecture, and the cross-entropy loss function is applied for multiclass classification in both the category classification model and the attribute classification model.

For attribute multitask classification models, the loss function is described as follows:

3. Technical Details

The object ontology is designed manually based on professional experience and a public dataset for the community. It is organized into a hierarchical semantic tree with three main levels: the region level, the category level, and the attribute level. Regions, categories, and attributes are learned automatically by the local MDNN. Taking the DeepFashion dataset as an example, the functions used are described as follows:

  • (i) extract_predicates(dta): in a richly annotated dataset, e.g., DeepFashion, a sample image can be annotated with many labels at different fine-grained levels. For each fine-grained level, the function extracts the unique possible labels of the samples and stores them in a corresponding array. For example, in the DeepFashion dataset, Top, Bottom, and Body are unique labels belonging to one fine-grained level, so they are stored in one array. Similarly, the fabric, shape, part, style, and texture labels belong to one fine-grained level and are stored in one array.
  • (ii) build_ontology(predicates, prior): this function matches the extracted levels and their labels from each predicate array to the corresponding stages of the general ontology, i.e., the prior. For example, Top, Bottom, and Body belong to one level, which is matched with the region stage of the ontology. After matching, all unused stages are eliminated from the general ontology to generate the adapted ontology, e.g., the fashion ontology.
  • (iii) extract_state(onto): all stages of the built ontology and their labels are retrieved and stored in arrays, which are used to reconstruct the data. For example, the region stage array contains three classes, and the category stage array contains 50 classes.
  • (iv) extract_nes_dta(dta, state, onto): based on the stages and classes extracted by the “extract_state” function, the whole DeepFashion dataset is split. Only samples whose labels belong to a stage are stored as the training set of that stage in the ontology. For example, for the region classification model, only samples labelled Top, Body, or Bottom are used for training.
  • (v) classifyModel(architecture, state_dta): in the DeepFashion dataset, based on the ontology, there are four classification models: one region classification model plus one category classification model for each of the Top, Body, and Bottom regions. These models are retrained from ImageNet-pretrained weights using ResNet-10.
  • (vi) multitaskModel(group_state_dta, architecture, Matthrew_coef = True): for each group state at the fine-grained attribute level, a multitask classification model is built, e.g., a fabric attribute group classification model and a style attribute group classification model. These models are retrained from ImageNet-pretrained weights using NASNet v3. In addition, MCC is used in attribute learning as the imbalanced data solver.
  • (vii) indexing(state_sta): this function creates the indexing files used for run-time retrieval, based on nonexhaustive compressed-domain search with the GPU.
  • (viii) build_storage(onto, states): storage structures are created automatically based on the built object ontology and the extracted states.
  • (ix) infor_extract(states, dta, onto, classifyModels, multitaskModel): for each sample in the database, all attribute learning models trained with the “multitaskModel” function are run, and all attributes whose scores exceed their thresholds are extracted.
  • (x) feat_extract(dta, onto, classificationModels, multitaskModel): for each sample in the database, the pre-softmax-layer features of the four models trained with the “classifyModel” function are obtained.
  • (xi) structure(storage, feat_dta, info_dta, indexFiles): the database is built automatically from the extracted features, the extracted information, the index, and the storage structure.
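The first few steps (collecting labels per level, nesting them into a region ⟶ category ⟶ attribute tree, listing the classes of each stage, and splitting the data per stage) can be sketched as follows. The functions reuse the names from the list, but their signatures are simplified and hypothetical: this sketch builds the tree directly from the annotations instead of matching against a general prior ontology, and the toy annotations do not follow the real DeepFashion format.

```python
from collections import defaultdict

# Hypothetical toy annotations; the real DeepFashion annotation files differ.
samples = [
    {"region": "Top", "category": "Blazer", "attributes": ["denim", "long-sleeve"]},
    {"region": "Top", "category": "Tee", "attributes": ["cotton"]},
    {"region": "Bottom", "category": "Jeans", "attributes": ["denim"]},
    {"region": "Body", "category": "Dress", "attributes": ["floral"]},
]

def extract_predicates(data):
    """Unique labels per fine-grained level (region, category, attribute)."""
    levels = defaultdict(set)
    for s in data:
        levels["region"].add(s["region"])
        levels["category"].add(s["category"])
        levels["attribute"].update(s["attributes"])
    return {level: sorted(labels) for level, labels in levels.items()}

def build_ontology(data):
    """Nest labels as region -> category -> set of attributes."""
    onto = defaultdict(lambda: defaultdict(set))
    for s in data:
        onto[s["region"]][s["category"]].update(s["attributes"])
    return onto

def extract_state(onto):
    """Class list of each stage: the regions, and the categories across all regions."""
    return {"region": sorted(onto),
            "category": sorted(c for r in onto for c in onto[r])}

def extract_nes_dta(data, region):
    """Training samples for the category model of one region."""
    return [s for s in data if s["region"] == region]

onto = build_ontology(samples)
print(extract_state(onto))
# {'region': ['Body', 'Bottom', 'Top'], 'category': ['Blazer', 'Dress', 'Jeans', 'Tee']}
print(len(extract_nes_dta(samples, "Top")))  # 2
```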

3.1. Object Part Extraction

For the reasons mentioned above, foreground objects should be extracted from background regions efficiently and accurately before the retrieval step. Object extraction aims to isolate the specific subjects of interest, which also improves the efficiency of the subsequent modules and the overall system performance. Among the many successful object detection methods, YOLO shows state-of-the-art results. In our system, we reuse the YOLO (version 3.0) software to identify fashion items.

3.2. Transfer Learning

Transfer learning is one of the most effective ways to reduce training time, especially with complex architectures such as ResNet or NASNet. The key issue is the initial parameters. At the first step of training, these parameters must be generated by some unsupervised learning method, and such initial values are usually far from the optimal ones. In transfer learning, we instead reuse parameters trained on a large and diverse dataset (such as ImageNet). This way, the training process converges more easily, which reduces the training time.

Transfer learning can be applied in different ways depending on the size of the dataset and its similarity to the source data. There are four scenarios in total. First, if the data size is small and data similarity is high, we use the pretrained model as a feature extractor. Second, if the data size is small and data similarity is low, we freeze the top layers and train the remaining layers of the pretrained model. Third (ideal situation), if the data size is large and data similarity is high, we can retrain the model using the weights initialized from the pretrained one. Fourth (worst situation), if the data size is large and data similarity is low, transfer learning cannot be applied, and we have to train our model from scratch. In our fashion experiments, since DeepFashion is a large dataset and ImageNet (the dataset used for transfer learning) is highly diverse, we can use all of the initialized weights from the pretrained model.
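The four scenarios amount to a small decision table over dataset size and similarity. The sketch below encodes the rules exactly as stated above; the function name and the boolean inputs are hypothetical simplifications of what are, in practice, judgment calls.

```python
def transfer_strategy(large_dataset: bool, similar_to_source: bool) -> str:
    """Map dataset size and similarity to the four scenarios described in the text."""
    if not large_dataset and similar_to_source:
        return "use the pretrained model as a feature extractor"
    if not large_dataset and not similar_to_source:
        return "freeze the top layers and train the remaining layers"
    if large_dataset and similar_to_source:
        return "retrain the model starting from the pretrained weights"
    return "train the model from scratch"

# DeepFashion is large, and ImageNet is treated as close enough because it is diverse.
print(transfer_strategy(large_dataset=True, similar_to_source=True))
# retrain the model starting from the pretrained weights
```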

In our approach, transfer learning is applied to region and category classification with ResNet and to attribute learning with NASNet, respectively. It can also be used in global deep feature extraction to improve the overall retrieval performance.

3.3. Data Storing

Features extracted from the category classification task and from attribute learning are stored in a hierarchical semantic tree based on the object ontology. All features belonging to the same leaf of the object ontology are stored in one file. When the data grow to a large scale, these files can be indexed and split, with a corresponding mapping key for each image. The folders are organized according to the object ontology, with each folder name corresponding to a concept. Concretely, data storing for the proposed ontology is defined as follows (see the figure below for an example):

  • (i) All files are stored in a folder named “database,” which is denoted as the “Object” node.
  • (ii) Based on ontology, “Object” node contains 3 nodes at the “Region” semantic level. Thus, we have 3 smaller folders: “Top,” “Body,” and “Bottom.”
  • (iii) At the next stage of ontology, we have the “Category” semantic level. Thus, we have 50 folders representing all nodes of “Category.”
  • (iv) Finally, we have the “Attribute” semantic level, which stands for the leaf node state in the ontology. At this state, all features belonging to the same “Region” and “Category” are stored in one file.

Figure 1. An example of the storing structure.
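The folder layout above maps each (Region, Category) pair to one leaf file under the “database” root. A minimal sketch using `pathlib`; the leaf file name `features.dat` and the record format are hypothetical.

```python
from collections import defaultdict
from pathlib import PurePosixPath

ROOT = PurePosixPath("database")  # the "Object" node

def leaf_path(region, category):
    """database/<Region>/<Category>/features.dat (the file name is hypothetical)."""
    return ROOT / region / category / "features.dat"

def group_features(records):
    """Group (image_id, region, category, feature) records by leaf file.

    Each leaf keeps (image_id, feature) pairs, so the image id can serve as the
    mapping key when a large leaf file is later indexed or split.
    """
    leaves = defaultdict(list)
    for image_id, region, category, feature in records:
        leaves[leaf_path(region, category)].append((image_id, feature))
    return leaves

records = [("img1", "Top", "Blazer", [0.1, 0.2]),
           ("img2", "Top", "Blazer", [0.3, 0.1]),
           ("img3", "Bottom", "Jeans", [0.5, 0.5])]
for path, items in sorted(group_features(records).items()):
    print(path, [image_id for image_id, _ in items])
# database/Bottom/Jeans/features.dat ['img3']
# database/Top/Blazer/features.dat ['img1', 'img2']
```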


In the next post, we will discuss the online phase in detail.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 3)

In the previous post, we discussed the imbalanced data problem and gave an overview of the object retrieval system. The CFOR system is organized into two main phases: an offline phase and an online phase. In this post, we give an overview of both phases of the proposal.

1. Overview of Offline Phase

This phase generates the object ontology, the database, the indexing file, and the region detection, category classification, and attribute classification models.

  • Object ontology is designed manually based on professional experience and a public dataset for the community. It is organized into a hierarchical semantic tree with three main levels: region level, category level, and attribute level.
  • The database is generated to store the pre-extracted features, regions, categories, and attributes of all images in the dataset. It helps reduce the online retrieval time and provides the necessary semantic information for each retrieved image.
  • The indexing file, created to support fast mapping in the online phase of the CFOR retrieval system, is the key to performing the retrieval task at runtime.
  • Regions, categories, and attributes are learned automatically based on the local MDNN. Detection models and classification models are created to extract or predict semantic information of the query image and dataset such as regions, categories, and attributes.

2. Overview of Online Phase

This phase of the CFOR system is designed to run the retrieval process including object detection, semantic information extraction, and query expansion and retrieval.

In the object detection stage, we use the trained object detector to detect objects in the query image. In the semantic information extraction stage, the built object ontology and classification models are used to extract the necessary semantic information of each identified object. The extracted semantic information and the deep global features of each detected object are passed through the search system along with the indexing file to quickly compute the score between the query object and each sample in the database. Retrieval then ranks and exports the images most similar to the query object together with their relevant information. Query expansion is optional and increases retrieval performance at the cost of retrieval time.

The power of mutually supporting object ontology, local MDNN, and imbalanced data solver in the CFOR system: Figure 1 shows the operation of the CFOR system through the interaction of its three main modules, object ontology, a local MDNN, and an imbalanced data solver, which together optimize the learning strategy and improve the overall retrieval performance on large-scale datasets.

Figure 1. Synthesis of object ontology, deep learning, and imbalanced data problem solver in the CFOR system.

Object ontology supports conducting the training flow (with a local MDNN) and the retrieval flow (from the coarse-grained level to the fine-grained level) to save computational costs in the training and retrieval stages on large-scale datasets. The training flow also paves the way for applying transfer learning, which may improve the convergence rate of deep networks. By transforming the global imbalance of data into local imbalances within fine-grained groups, object ontology makes the imbalanced data problem easier to deal with.

The deep multitask NN effectively links the object ontology to the raw data at the category and attribute levels by exploiting inner-group and intergroup correlations. The object ontology supports updating the system at the local level with parallel processing based on the local MDNN. Therefore, CFOR can be updated flexibly on large-scale datasets with many variations. In addition, the proposed MCC-based imbalanced data solver addresses data imbalance and contributes effectively to the quality of the object ontology implementation without adjusting the network architecture or using data augmentation.

Algorithm and demonstration of the CFOR system: the online and offline phases (Figures 2 and 3) are used to analyze the tasks in the CFOR system. These phases are demonstrated in detail in the following posts.

Figure 2. Offline stage of the CFOR system.

Figure 3. Online stage of the CFOR system.

In addition, the CFOR system can be used as a general retrieval solution. To evaluate the performance of the proposed system, fashion objects with attributes are selected for the experiments.


In the next post, we will discuss the offline phase in detail.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 2)

In the previous post, we discussed multitask learning and object retrieval systems. In this post, we explore the imbalanced data problem and the main proposal.

1. Imbalanced Data Problem

Imbalanced data is a machine learning problem in which the class distribution is not uniform across classes. Usually, such data are composed of two types of classes: the majority classes (positive) and the minority classes (negative). Recent research in machine learning shows that learning from an uneven distribution of class examples can give learning algorithms misleading (biased) performance: a classifier may achieve high accuracy on the majority classes but poor accuracy on the minority classes. In attribute learning, an imbalance occurs if the number of instances of some attributes differs significantly from that of other attributes. In general, many popular methods handle imbalanced data by adjusting the class distribution.

  • Data sampling: sampling-based methods such as upsampling, downsampling, or data augmentation are considered solutions to imbalanced data problems. In addition to making data more balanced, they can reduce training time (downsampling) or make the learning process more efficient (upsampling). The best approach we know of is SMOTE, which automatically generates additional data (upsampling) based on the original dataset. However, these methods risk overfitting during training (upsampling) or losing data (downsampling). Data augmentation has proved robust in dealing with imbalanced training data. However, it consumes a lot of training resources, and it is difficult to find a suitable augmented dataset that is large enough for training. Moreover, it is very difficult (or impossible) to augment data to balance the attributes in datasets because each object usually has many attributes.

    Figure 1. A simple data sampling

  • Architecture, loss function, and metric configuration: other methods exploit network architectures, loss functions, or metrics to address the imbalanced data problem during training. These algorithm-level methods enhance the existing classifier by adjusting the algorithm to recognize the smaller classes. Such internal techniques provide general solutions for the imbalanced data problem because they are not specific to particular problems. These approaches perform better than data sampling; however, they are often difficult to implement and to reconfigure later. Therefore, they are not always the best choice for dynamic retrieval systems in which the attributes vary widely.
  • Threshold and output-based configuration: instead of generating more data or changing the model, these methods find the best thresholds based on the output. Their essence is to use scores that express the probability that a test sample belongs to a class, producing several learners by varying the class-membership threshold. These methods are particularly effective at resolving imbalanced data problems without changing the model configuration. Moreover, they neither reduce data nor increase overfitting. SVM has been proposed to find these thresholds, whereas Boughorbel et al. proposed Matthews’ correlation coefficient (MCC) to deal with imbalanced data in classification. Although SVM shows better performance, MCC consumes fewer resources and less processing time. Building on the methods of many other researchers, we found a solution for multitask learning that suits retrieval systems: an end-to-end DCNN for training and MCC for estimating the thresholds that produce the final outputs.

Figure 2. Metric losses
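The text states that MCC is used to estimate per-attribute thresholds. A minimal sketch of that idea, assuming the threshold is chosen by scanning the observed scores of a validation set and keeping the one with the highest MCC; the search strategy and the toy data are assumptions, not the authors' exact procedure.

```python
import math

def mcc(y_true, y_pred):
    """Matthews correlation coefficient for binary (0/1) labels; 0.0 if undefined."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t and p)
    tn = sum(1 for t, p in zip(y_true, y_pred) if not t and not p)
    fp = sum(1 for t, p in zip(y_true, y_pred) if not t and p)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t and not p)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

def best_threshold(y_true, scores):
    """Scan the observed scores and keep the threshold with the highest MCC."""
    best_t, best_m = None, -1.0
    for t in sorted(set(scores)):
        m = mcc(y_true, [1 if s >= t else 0 for s in scores])
        if m > best_m:
            best_t, best_m = t, m
    return best_t, best_m

# Hypothetical validation data for one rare attribute (2 positives out of 8).
y_true = [1, 0, 0, 0, 0, 0, 0, 1]
scores = [0.35, 0.10, 0.20, 0.05, 0.30, 0.15, 0.25, 0.40]
print(mcc(y_true, [1 if s >= 0.5 else 0 for s in scores]))  # 0.0 with a fixed 0.5 cut
print(best_threshold(y_true, scores))                       # (0.35, 1.0)
```

With a default 0.5 cut the rare attribute is never predicted, while the MCC-selected threshold separates the two positives from the six negatives; MCC is used here because, unlike accuracy, it stays low when a classifier simply predicts the majority class.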

2. Materials and Methods: CFOR System

Although the CFOR system is complex, it is easy to understand. Here we focus on its main points.

CFOR is an object retrieval system that integrates object ontology, a local MDNN (NASNet and ResNet), and an imbalanced data solver (MCC) to improve the performance of large-scale object retrieval from the coarse-grained level (categories) to the fine-grained level (attributes) (see the figure below).

Figure 2. Synthesis of object ontology, deep learning, and imbalanced data problem solver in the CFOR system.

Query Image. Traditional content-based image retrieval systems can only retrieve images ranked by their visual similarity to the query image; it is very difficult (or impossible) for users to provide semantic information to such systems through query images. Our CFOR system addresses this challenge: the semantic information of the query image is extracted automatically by the category and attribute classification system, and users can use the extracted semantic information during the retrieval process. For example, users may want to query “Asian face” with only a query image; here, “Asian race” is the semantic information. Traditional retrieval methods cannot meet this requirement because of the semantic gap, whereas the CFOR system can recognize “Asian race” and use it for retrieval. Another example, for the “Fashion” object, is shown in Figure 3.

Figure 3. Extracting regions, categories, and attributes from a query image with trained models of the CFOR system. After that, users can use this semantic information to reduce the searching space.

From the query image, based on fashion ontology, the detector quickly identifies the region (Top and Bottom; see Figure 4).

Figure 4. Fashion ontology used to retrieve.

After that, the user selects a region (Top; see Figure 5), and the CFOR system quickly identifies the category related to the Top region (category: Blazer). Next, specific concepts and visual concepts related to Blazer are extracted, and users can select some (or all) of them for retrieval. For user-friendly interaction, only the extracted regions, categories, and attributes are shown; other information used as search input, such as global deep features, attribute vectors, the ontology, or attribute groups, is not displayed. In this way, users can query the CFOR system at the semantic level and obtain results that match both the content and the semantics of the query image.

Figure 5. Retrieval results.


In the next post, we will discuss the online and offline phases of the proposed retrieval system.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 1)

Object retrieval plays an increasingly important role in video surveillance, digital marketing, e-commerce, and other applications. It faces challenges such as large-scale datasets, imbalanced data, viewpoint variation, cluttered backgrounds, and fine-grained details (attributes). This paper proposes a model that integrates an object ontology, a local multitask deep neural network (local MDNN), and an imbalanced data solver to take advantage of deep learning network models and overcome their shortcomings, improving the performance of large-scale object retrieval from the coarse-grained level (categories) to the fine-grained level (attributes). Our proposed coarse-to-fine object retrieval (CFOR) system is designed to be robust to the challenges listed above. To the best of our knowledge, the main novelty of our CFOR system is the mutual support of the object ontology, the local MDNN, and the imbalanced data solver in a unified system. The object ontology exploits inner-group correlations to improve category classification and attribute classification, and it conducts the training flow and retrieval flow to save computational costs in the training and retrieval stages on large-scale datasets. The local MDNN links the object ontology to the raw data, and an imbalanced data solver based on Matthews’ correlation coefficient (MCC) addresses data imbalance, which contributes effectively to the quality of the ontology realization without adjusting the network architecture or using data augmentation. To evaluate the performance of the CFOR system, we experimented on the DeepFashion dataset.
This paper shows that our local MDNN framework based on the pretrained NASNet architecture achieves better performance (14.2% higher recall) than single-task learning (STL) in the attribute learning task, and that our model with the imbalanced data solver achieves better performance (5.14% higher recall for attributes with less data) than models that do not account for imbalance. Moreover, MAP@30 is around 0.815 for retrieval averaged over 35 imbalanced fashion attributes.
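MAP@30 averages, over queries, the precision at every rank where a relevant item appears among the top 30 results. Definitions of AP@k differ in their normalization; the sketch below assumes division by min(k, number of relevant items), which may not match the exact convention used for the reported 0.815. The example relevance lists are made up for illustration.

```python
def average_precision_at_k(ranked_relevance, num_relevant, k=30):
    """AP@k for one query.

    ranked_relevance: 0/1 flags for the ranked results (1 = relevant).
    num_relevant: total number of relevant items for this query.
    """
    hits, score = 0, 0.0
    for rank, rel in enumerate(ranked_relevance[:k], start=1):
        if rel:
            hits += 1
            score += hits / rank  # precision at this rank
    denom = min(k, num_relevant)
    return score / denom if denom else 0.0


def map_at_k(queries, k=30):
    """Mean of AP@k over (ranked_relevance, num_relevant) pairs."""
    return sum(average_precision_at_k(r, n, k) for r, n in queries) / len(queries)


queries = [([1, 0, 1], 2), ([0, 1], 1)]  # toy data, not from the paper
print(round(map_at_k(queries, k=30), 4))  # → 0.6667
```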

1. Introduction

Today, object retrieval faces several challenges, but it also has some advantages.

Query format plays a very important role in large-scale object retrieval systems, so it should be user-friendly and satisfy user requirements in practice.

Two query formats are popular today: image-based and text-based. The text-based format is widely used in many search systems. However, it is often very difficult to express the content a person wants to retrieve with query text, because words are limited in expressing visual information. A query image, on the other hand, is worth more than a thousand words: it allows customers to search for objects without typing, and, most importantly, it retrieves results based on content. Nevertheless, the limited ability of a query image to express semantic information can decrease overall retrieval performance. Therefore, query images and retrieved images accompanied by useful related information (regions, categories, fine-grained attributes, etc.) are the key points we focus on to improve the performance of the coarse-to-fine object retrieval system.

Object retrieval systems should be able to retrieve from large-scale datasets not only at the coarse level but also at the detailed (attribute) level. For example, face retrieval systems often require facial attribute retrieval, and in fashion retrieval systems, fashion attribute retrieval is indispensable. In person reidentification systems, the reidentification stage exploits attribute vectors of the face and clothes in addition to global features of the whole human body. In crowd attribute recognition systems, the useful attribute set consists of location, participants, and activities.

Objects often have multiple attributes, and there are methods to retrieve objects at the attribute level from large-scale datasets without manual annotation. In attribute recognition, traditional methods often spend a lot of time selecting hand-crafted features for each attribute group through trial and error, yet do not always achieve the desired results. In recent years, deep convolutional neural networks (DCNNs) have demonstrated high performance in many computer vision tasks such as detection, classification, recognition, and retrieval. DCNNs are also used for attribute learning: a single DCNN architecture can learn to recognize many attributes.

The performance of a DCNN-based attribute learning model will not be high if all attributes play the same role at the output level of the network and data imbalance is left unresolved. To exploit the inner-group correlations in coarse-grained or fine-grained groups, the DCNN is often revised into a deep multitask neural network. Classification performance improves when the elements of fine-grained category groups or fine-grained attribute groups share similar learned features, because the slope of their error surface becomes more uniform and the deep multitask learning algorithm can more easily reach the global optimum.

The object ontology plays an important role in category classification and attribute classification, and in conducting the training flow and retrieval flow to save computational costs in the training and retrieval stages on large-scale datasets. Based on our experience researching attribute-rich objects such as faces, clothes, persons (reidentification), crowds (monitoring), and fast filters in large-scale object retrieval, we introduce an object ontology as a hierarchical semantic tree with three levels: region, category, and attribute. The attribute level consists of visual concepts and specific concepts. Visual concepts link common visual attributes to arbitrary objects.
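The three-level tree described above maps naturally onto nested records. This is a minimal sketch, assuming each category stores its specific concepts and a list of shared visual concepts; the class names, `attributes_for`, and all entries are hypothetical and only illustrate the region → category → attribute structure.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Category:
    """Category node; its children are the attribute-level concepts."""
    name: str
    specific_concepts: list[str] = field(default_factory=list)  # attributes specific to this category
    visual_concepts: list[str] = field(default_factory=list)    # common visual attributes shared across objects


@dataclass
class Region:
    """Region node (e.g. Top, Bottom) holding its categories."""
    name: str
    categories: dict[str, Category] = field(default_factory=dict)


def attributes_for(ontology, region, category):
    """Walk region -> category and return every attribute under that category."""
    node = ontology[region].categories[category]
    return node.specific_concepts + node.visual_concepts


# Hypothetical entries for illustration only.
VISUAL = ["red", "floral", "striped"]
fashion = {
    "Top": Region("Top", {"Blazer": Category("Blazer", ["notched-lapel"], VISUAL)}),
    "Bottom": Region("Bottom", {"Jeans": Category("Jeans", ["distressed"], VISUAL)}),
}
print(attributes_for(fashion, "Top", "Blazer"))
# → ['notched-lapel', 'red', 'floral', 'striped']
```

Sharing one `VISUAL` list across categories mirrors the role of visual concepts as attributes that can attach to arbitrary objects.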

We build our object ontology on popular large-scale standard datasets in the scientific community, so we hope it meets the criterion of being “widely recognized in the community.” For the “realization” criterion, we propose the local MDNN to link the object ontology to the raw data. However, an ontology that cannot be linked to the data with high quality cannot function effectively. We therefore propose an imbalanced data solver based on MCC that addresses data imbalance and effectively improves the quality of linking the ontology to raw data, without adjusting the network architecture or using data augmentation.

We review some typical works based on object ontology, deep multitask neural networks, and imbalanced data solvers to highlight our contributions.

Most works present the set of attributes only as item lists or item groups. A few works use the term "ontology", but to the best of our knowledge, no existing work presents an object ontology in the full sense of regions, categories, and attributes.

In [8], FashionNet handles challenges such as deformation and occlusion by explicitly predicting clothing landmarks and pooling features over the estimated landmarks, which yields a more discriminative clothing representation. The authors do not use the term "ontology", but the DeepFashion dataset is organized as a hierarchical tree. It covers fashion only and has two levels: the first level consists of 50 categories, and the second consists of 5 attribute groups (texture, fabric, shape, part, and style), with no color attribute. The coarse-grained groups (at the category level) and fine-grained groups (at the attribute level) play the same role in the deep neural network, and data imbalance is not addressed.

Our idea of improving deep neural networks with an object ontology and an imbalanced data solver is inspired by Gödel’s incompleteness theorems, which show the limitations of any consistent formal system and, by analogy, the limitations of any specific method for solving problems. Since configuring deep network architectures no longer yields gains as large as it did in the early days, it is necessary to integrate object ontologies and imbalanced data solvers into deep learning. Through appropriate interventions on inputs and outputs, we introduce a new method that helps improve the performance of the object retrieval system.

The main contributions of this paper are as follows.

  • Our proposed unified model consists of an object ontology, a local MDNN, and an imbalanced data solver, which together improve the performance of the large-scale object retrieval system from the coarse-grained level (categories) to the fine-grained level (attributes).

  • Our proposed object ontology is a hierarchical semantic tree with three main levels: region, category, and attribute. It supports an optimal learning strategy and minimizes the effect of the semantic gap. It improves category classification and attribute classification, and it conducts the training flow and retrieval flow to save computational costs in the training and retrieval stages on large-scale datasets.

  • Our proposed local MDNN is inspired by multitask neural networks. Built on NASNet and ResNet backbones with a local multitask architecture, it improves category and attribute classification and allows flexible system updates. The local MDNN links the object ontology to raw data and takes advantage of the inner-group correlations of categories and attributes. When inner-group (or intergroup) correlations are exploited, classification performance improves, because the elements of a fine-grained category or attribute group share similar learned features, the slope of their error surface becomes more uniform, and our deep local multitask learning algorithm can more easily reach the global optimum.

  • Data imbalance often occurs in large-scale datasets. Data augmentation is almost impossible because each object can have multiple attributes. A solution based on loss functions, as in [6], may be possible, but it cannot exploit transfer learning. Our proposed imbalanced data solver is derived from MCC and requires no changes to the network architecture and no data augmentation. It is integrated into the local MDNN to improve category and attribute classification while still exploiting transfer learning to reduce computational costs in the training stage on large-scale datasets.

  • Our proposed query format is based on the object ontology, with semantic information such as regions, categories, and attributes extracted automatically from the query image. This lets us carry semantic information from the image into the retrieval process, which traditional methods have not yet implemented.
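The imbalanced data solver builds on Matthews’ correlation coefficient, MCC = (TP·TN − FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)), which stays informative when positives are rare. The post does not spell out how the solver applies MCC. As one hypothetical way of intervening at the output without touching the architecture, the sketch picks a per-attribute decision threshold that maximizes MCC; `best_threshold` and the toy scores are assumptions, not the authors' method.

```python
import math


def mcc(tp, tn, fp, fn):
    """Matthews correlation coefficient from confusion-matrix counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


def best_threshold(scores, labels, candidates):
    """Hypothetical output-side fix: choose the threshold that maximizes MCC for one attribute."""
    def mcc_at(t):
        tp = sum(s >= t and y == 1 for s, y in zip(scores, labels))
        tn = sum(s < t and y == 0 for s, y in zip(scores, labels))
        fp = sum(s >= t and y == 0 for s, y in zip(scores, labels))
        fn = sum(s < t and y == 1 for s, y in zip(scores, labels))
        return mcc(tp, tn, fp, fn)
    return max(candidates, key=mcc_at)


# Toy rare attribute: 2 positives out of 6 samples.
scores = [0.9, 0.35, 0.3, 0.2, 0.1, 0.05]
labels = [1, 1, 0, 0, 0, 0]
print(best_threshold(scores, labels, [0.5, 0.33, 0.3]))  # → 0.33
```

With the default 0.5 threshold, the second positive is missed; MCC rewards the lower threshold because it accounts for all four confusion-matrix cells rather than just accuracy.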

Figure 1. A use case of the object retrieval system.

2. Object Retrieval System

Fine-grained object retrieval searches for similar images that contain specific object attributes. It marks a transition from image retrieval to object attribute retrieval. Unlike traditional image retrieval systems, whose queries and results are often coarse (e.g., texts or images), fine-grained image retrieval aims to retrieve semantic information such as categories and attributes. In the fashion field, an object retrieval method that takes advantage of semantic information by combining global features with fine-grained attribute information was introduced in [8]. Inspired by previous works, we propose a coarse-to-fine object retrieval system that not only combines global features with fine-grained attribute information but also optimizes the learning strategy based on the ontology and resolves the imbalanced data problem by intervening at the output.

Beyond returning semantically relevant results, the object retrieval system must handle large-scale data in real time. However, most solutions do not take advantage of GPU parallel processing, which can significantly reduce feature-matching and retrieval time. To leverage GPUs, we adopted the nonexhaustive similarity search algorithm introduced by Johnson et al. (billion-scale similarity search with GPUs). This search method suits the proposed CFOR system well, and searching time is reduced further by creating multi-index files based on the built-in object ontology.
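The nonexhaustive search in Johnson et al. is built on an inverted-file (IVF) structure: vectors are bucketed by their nearest coarse centroid, and a query scans only the `nprobe` closest buckets. The Faiss library that accompanies that paper provides GPU implementations; the toy CPU sketch below only illustrates the idea, and the names `IVFIndex`, `kmeans`, and `sq_l2` are hypothetical. One way to realize ontology-based multi-index files, an assumption here, would be to build one such index per category.

```python
import random


def sq_l2(a, b):
    """Squared L2 distance (same ranking as L2, cheaper to compute)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(vectors, n_lists, iters=10, seed=0):
    """Plain Lloyd's k-means to learn the coarse quantizer centroids."""
    rng = random.Random(seed)
    centroids = [list(v) for v in rng.sample(vectors, n_lists)]
    for _ in range(iters):
        buckets = [[] for _ in range(n_lists)]
        for v in vectors:
            nearest = min(range(n_lists), key=lambda c: sq_l2(v, centroids[c]))
            buckets[nearest].append(v)
        for c, members in enumerate(buckets):
            if members:  # keep the old centroid if a list ends up empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


class IVFIndex:
    """Toy inverted-file index: each vector lives in the list of its nearest centroid."""

    def __init__(self, vectors, n_lists=8, seed=0):
        self.vectors = vectors
        self.centroids = kmeans(vectors, n_lists, seed=seed)
        self.lists = [[] for _ in range(n_lists)]
        for idx, v in enumerate(vectors):
            self.lists[self.nearest_lists(v, 1)[0]].append(idx)

    def nearest_lists(self, v, n):
        """Ids of the n centroids closest to v."""
        return sorted(range(len(self.centroids)),
                      key=lambda c: sq_l2(v, self.centroids[c]))[:n]

    def search(self, query, k=5, nprobe=1):
        """Scan only the nprobe closest lists instead of the whole collection."""
        candidates = [i for c in self.nearest_lists(query, nprobe) for i in self.lists[c]]
        candidates.sort(key=lambda i: sq_l2(query, self.vectors[i]))
        return candidates[:k]


rng = random.Random(42)
data = [[rng.random() for _ in range(8)] for _ in range(200)]
index = IVFIndex(data, n_lists=8)
print(index.search(data[0], k=3, nprobe=2))  # data[0] itself comes first
```

Raising `nprobe` trades speed for recall; with `nprobe` equal to the number of lists, the search becomes exhaustive.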

Figure 2. Original object retrieval system.

3. Attribute Learning

Attribute learning is the backbone of CFOR and strongly affects the performance of fine-grained object retrieval. Therefore, it is one of the most important parts of the learning strategy.

Attribute Learning.

This method is used for object recognition at the fine-grained level. Unlike learning methods for high-level concepts, attribute learning targets mid-level semantic concepts, or visual concepts, which are known to be (more or less) correlated with each other. There are two main learning methods: single-task learning and multitask learning.

  • Single-task attribute learning: each attribute has its own learning model, so the number of models equals the number of attributes. Moreover, each attribute is treated separately, so inner-group correlations are not exploited. Many works in the fashion field use single-task learning for fashion attributes; at that time, multitask learning faced many challenges. A shared CNN was later defined to pave the way for the final format of multitask multilabel predictions, which made multitask learning possible.

  • Multitask attribute learning: to apply this technique to attributes, samples are collected by merging the given datasets into one, with labels represented as multi-hot binary vectors. As in single-task learning, the input is the image. However, whereas the output of single-task learning is a single value describing the presence (or absence) of one attribute in an image, the output of multitask learning is a multi-hot binary vector describing the presence (or absence) of each attribute in a group. Rudd has shown that joint optimization over all attributes outperforms training a separate network with the same architecture for each attribute (where the feature space is optimized together with the classifier on a per-attribute basis), in terms of accuracy, storage, and processing efficiency. This result shows that the multitask approach exploits latent correlations much more effectively than independent classifiers.
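Because an image can carry several attributes at once, the multitask target is a multi-hot vector (possibly several 1s), and a common training objective is an independent sigmoid with binary cross-entropy per attribute, averaged over the vector. The sketch assumes that objective; the attribute vocabulary and helper names are hypothetical, and the logits would come from the shared network in practice.

```python
import math

ATTRIBUTES = ["floral", "striped", "denim", "lace"]  # hypothetical attribute vocabulary


def multi_hot(present, vocab=ATTRIBUTES):
    """Encode the set of attributes present in an image as a 0/1 vector."""
    return [1 if a in present else 0 for a in vocab]


def multilabel_bce(logits, target):
    """Mean binary cross-entropy over all attributes, one sigmoid output each."""
    total = 0.0
    for z, y in zip(logits, target):
        # numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(target)


target = multi_hot({"floral", "lace"})
print(target)  # → [1, 0, 0, 1]
print(round(multilabel_bce([0.0, 0.0, 0.0, 0.0], target), 4))  # → 0.6931
```

A single loss over the whole vector is what lets one network learn all attributes jointly, in contrast to one model per attribute.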
Although multitask learning can yield better performance than single-task learning, its critical weakness is that the model cannot be reused when any attribute changes: adding a new attribute requires retraining or an additional model. This lack of reuse makes multitask learning methods inflexible for attributes that change frequently. To address this, we propose local multitask attribute learning, a grouping method based on the object ontology that improves reuse.
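The reuse argument can be illustrated by grouping attributes into ontology-based local heads and checking which heads change when an attribute is added. This is a sketch under assumptions: the grouping borrows DeepFashion-style groups (texture, fabric, shape), the attribute names are hypothetical, and `build_local_heads` and `heads_to_retrain` are illustrative helpers rather than part of the published system.

```python
from collections import defaultdict

# Hypothetical ontology grouping: attribute -> attribute group
ATTRIBUTE_GROUP = {
    "floral": "texture", "striped": "texture",
    "denim": "fabric", "lace": "fabric",
    "a-line": "shape",
}


def build_local_heads(attribute_group):
    """One local multitask head per ontology group, listing the attributes it predicts."""
    heads = defaultdict(list)
    for attr, group in attribute_group.items():
        heads[group].append(attr)
    return dict(heads)


def heads_to_retrain(old_heads, new_heads):
    """Groups whose attribute list changed and therefore need retraining."""
    return sorted(g for g in new_heads if old_heads.get(g) != new_heads[g])


heads = build_local_heads(ATTRIBUTE_GROUP)
updated = build_local_heads({**ATTRIBUTE_GROUP, "chiffon": "fabric"})
print(heads_to_retrain(heads, updated))  # → ['fabric']
```

Only the head for the affected group changes; the others can be reused as-is, which is the flexibility a single global multitask head lacks.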

Figure 3. Attribute learning model based on deep features with SVM classifiers.
Figure 4. Attribute learning model based on adaptive attribute domain with independent deep convolutional neural networks.
Figure 5. Attribute learning model based on the end-to-end deep neural network as a shared block with adaptive loss function.


In the next post, we will discuss the imbalanced data problem and our proposal in detail.

Controllable Group Choreography using Contrastive Diffusion (Part 6)

In the previous part, we explored the qualitative and quantitative experiments. In this part, let's investigate the ablation studies and user studies.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Ablation Analysis

Loss Terms. The contribution of the geometric loss $\mathcal{L}_{\rm geo}$ and the contrastive loss $\mathcal{L}_{\rm nce}$ in GCD is analyzed in Table 1. The results demonstrate that both losses play a crucial role in enhancing overall performance across all four evaluation protocols. In particular, the effect of $\mathcal{L}_{\rm geo}$ on the realism metrics (FID, GMR, and PFC) is significant. This can be attributed to the fact that this loss improves the physical plausibility and naturalness of the dance motions, empirically mitigating common artifacts such as jittery motion or foot skating. By enforcing geometric constraints, GCD can generate faithful motions that are on par with real dances. Moreover, the contrastive loss $\mathcal{L}_{\rm nce}$ contributes positively to the favorable results on the synchrony measures (GMR and GMC). This loss term encourages the model to synchronize the movements of multiple dancers within a group, improving the harmonious coordination and cohesion of the generated choreographies. Overall, these results validate the importance of $\mathcal{L}_{\rm geo}$ and $\mathcal{L}_{\rm nce}$ across various evaluation metrics.
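This post does not spell out the form of $\mathcal{L}_{\rm nce}$. For intuition, the sketch below implements a generic InfoNCE loss, which scores a positive pair against negatives with a temperature-scaled softmax over cosine similarities. What plays the role of anchor, positive, and negatives in GCD (for example, motion embeddings versus group or music context) is an assumption here, and `info_nce` is a hypothetical helper, not the paper's exact loss.

```python
import math


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def info_nce(anchor, positive, negatives, temperature=0.1):
    """Generic InfoNCE: -log of the softmax weight of the positive among positive + negatives."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, n) / temperature for n in negatives]
    m = max(logits)  # subtract the max for numerical stability
    log_denom = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_denom - logits[0]


aligned = info_nce([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0]])     # positive matches the anchor
misaligned = info_nce([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0]])  # a negative matches instead
print(aligned < misaligned)  # → True
```

Minimizing this loss pulls the positive toward the anchor and pushes negatives away, which is the general mechanism behind using a contrastive term to encourage coordinated group motion.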

Table. 1. Global module contribution and loss analysis. Experiments are conducted on GCD with $\gamma = 0$ (neutral mode).

Group Global Attention. The results in the first two rows of Table 1 demonstrate the substantial improvements obtained by incorporating Group Global Attention into GCD. Without Group Global Attention, performance on the group dance metrics (GMR and GMC) is significantly degraded. We also observe that removing this block results in inconsistent movements across dancers: they seem to dance freestyle without any group choreographic rules and often collide with each other, although they may still follow the rhythm of the music. These results suggest that ensuring coherence and regulating collisions is vital for visually appealing group dance animations.

2. User Study

Qualitative user studies are important for evaluating generative models, as user perception tends to be the most relevant metric for many downstream applications. Therefore, we conducted user studies to evaluate our approach to group choreography generation. We organized two separate studies and enlisted roughly 50 individuals with diverse backgrounds. Each participant was required to have some relevant experience in music and dance (at least 1 month of studying or working in dance-related professions). Participants were between 20 and 50 years old, approximately 55% female and 45% male.

In the first study, we asked participants to evaluate the dance animations on three criteria: the naturalness of the dancing motions (Realism), how well the movements match the music (Music-Motion Correspondence), and how well the dancers interact or synchronize with each other (Synchronization between Dancers). Participants rated each criterion from 0 to 10, where 0 is very poor, 5 is acceptable, and 10 is very good. The collected scores were then normalized to the range $[0, 1]$.

This user study included a total of 189 × 3 samples with songs not present in the training set, comprising results generated by GDanceR, real dance clips from the dataset, and results generated by our method in neutral mode. Figure 1 shows the average scores of the three targets across the three criteria. Notably, our method is rated significantly higher than GDanceR on all three criteria. We also performed Tukey's honestly significant difference (HSD) tests to determine significant differences among the three methods. For the first two criteria (Realism and Music-Motion Correspondence), the mean scores of all methods are significantly different with $p < 0.05$. For the synchronization criterion, the differences are significant except between our method and real dances ($p \approx 0.07$). This indicates that our method can achieve scores comparable to real dances, especially in the synchronization evaluation. This can be attributed to the proposed contrastive diffusion strategy, which effectively balances the consistency of the movements with the group/audio context and the diversity of the generated dances.

Figure. 1. User study results in three criteria: Realism; Music-Motion Correspondence; and Synchronization between Dancers

In the second study, we aimed to assess the diversity and consistency of the generated dances and determine whether they met users' expectations. Specifically, participants were asked to assign each dance clip a score from 0 to 10 for consistency and diversity, i.e., how synchronized or how distinct the movements of the dancers in the group are. A lower score indicated higher consistency, while a higher score indicated greater diversity. These scores were then normalized to $[-1, 1]$ to match the range of the control parameter used in our method.

Figure 2 shows a scatter plot of the relationship between the participants' scores and the $\gamma$ parameter used to generate the dance samples. The parameter values were drawn from a uniform distribution over $[-1, 1]$ to create the animations, along with randomly sampled musical pieces. The survey shows a strong correlation between the user scores and the control parameter, with a correlation coefficient of approximately 0.88. This indicates that the diversity and consistency levels of the generated group choreographies largely agree with the user evaluations.
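The two steps above, rescaling 0 to 10 ratings onto the parameter range and correlating them with $\gamma$, can be sketched as follows. The post does not name the correlation measure; Pearson's coefficient is assumed here, and the sample values are toy numbers, not the study's data.

```python
import math


def normalize(score, lo=0.0, hi=10.0, new_lo=-1.0, new_hi=1.0):
    """Linearly map a rating from [lo, hi] to [new_lo, new_hi]."""
    return new_lo + (score - lo) * (new_hi - new_lo) / (hi - lo)


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


# Toy example: gamma values used for generation and raw 0-10 user ratings.
gammas = [-1.0, -0.5, 0.0, 0.5, 1.0]
ratings = [1, 3, 6, 7, 9]
scores = [normalize(r) for r in ratings]
print(round(pearson(gammas, scores), 3))
```

The same `normalize` call with `new_lo=0.0, new_hi=1.0` gives the $[0, 1]$ scaling used in the first study.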

Figure. 2. Correlation between the controlling consistency/diversity and the scores provided by the users.


While controlling consistency and diversity in group dance generation with GCD has numerous advantages and potential, there are certain limitations. First, it requires parameter tuning and a complex system that is not trivial to train, to ensure that the generated dance motions have the desired level of similarity among dancers while still presenting enough variation to avoid repetitive or monotonous movements. This may involve long inference processes and require significant computational resources in both the training and testing phases.

Secondly, over-controlling consistency and diversity may introduce constraints on the creative freedom of generated dances. While enforcing consistency can lead to synchronized and harmonious group movements, it may limit the possibility of exploring unconventional or new experimental dance styles. On the other hand, promoting diversity results in unique and innovative dance sequences, but it may sacrifice coherence and coordination among dancers.

Although our model can synthesize semantically faithful group dance animation with effective coordination among dancers, it does not capture clear physical contact between dancers, such as hand touching, because our training data do not contain such detailed hand motion information. We think exploring group dance with realistic physical hand interactions is a promising area for future work. Additionally, while our method offers a trade-off between diversity and consistency, achieving perfect alignment between high-diversity movements and music remains challenging. The diversity level among dancers and the alignment with music are also heavily influenced by the training data. We believe further effort is required to achieve this.

Lastly, the subjective nature of evaluating consistency and diversity poses a challenge, and existing metrics for measuring these aspects may not be well suited. We believe it is essential to consider diverse perspectives and to involve domain experts to validate the effectiveness and quality of the generated dance motions.

To conclude, we have introduced GCD, a new method for audio-driven group dance generation that effectively controls the consistency and diversity of generated choreographies. By using contrastive diffusion along with the guidance technique, our approach enables the generation of a flexible number of dancers and long-term group dances without compromising fidelity. Through our experiments, we have demonstrated the capability of GCD to produce visually appealing and synchronized group dance motions. The results of our evaluation, including comparisons with existing methods, highlight the superior performance of our method across various metrics including realism and synchronization. By enabling control over the desired levels of consistency and diversity while preserving fidelity, our work has the potential for applications in entertainment, virtual performances, and artistic expression, advancing the effectiveness of deep learning in generative choreography.

Controllable Group Choreography using Contrastive Diffusion (Part 5)

In the previous part, we explored the experimental setups and quality analysis. In this part, let's investigate the consistency and diversity factors in dance generation.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Diversity and Consistency Analysis

Table 1 presents an in-depth analysis of our method's performance across seven evaluation metrics as the parameter $\gamma$ is adjusted to control the consistency and diversity of the generated group choreographies. The high-consistency setting ($\gamma = 1$) performs better than the other settings on the MMC, GMC, and TIF metrics, whereas the high-diversity setting ($\gamma = -1$) achieves better results on the GenDiv metric. Meanwhile, the default model shows the best performance on both realism metrics (FID and GMR). The model is also relatively robust in terms of the physical plausibility score (PFC), as there are no noticeable differences in this metric across the three settings. This implies that our model can create group animations with different consistency or diversity levels without compromising the plausibility of the movements too much. More interestingly, we found positive correlations between the two measures MMC and GMC and the trade-off parameter: both metrics improve as the consistency level increases. This is reasonable, as we expect higher correspondence between the motion and the music (MMC) and higher correlation of the group motions (GMC) as consistency grows, which agrees with the definitions of these metrics.

Table. 1. Performance comparison. High Consistency: parameter $\gamma = 1$; High Diversity: parameter $\gamma = -1$; Neutral: parameter $\gamma = 0$

Consistency setups in GCD lead to more similar movements between dancers, and this similarity contributes to better scores on the MMC, GMC, and TIF metrics. In contrast, the diversity setups synthesize more complex motions with greater variation between dancers, as measured by the GenDiv metric, but this also makes it harder to achieve good FID, MMC, and TIF values than with the other setups. In addition, Figure 1 shows an example of correlations between motion beats and music beats under the high-consistency and high-diversity settings. The music beats are extracted using the beat tracking algorithm from the Librosa library. Notably, the velocity curves in the high-consistency setting have relatively similar shapes, whereas in the high-diversity setting the curves are clearly distinct among dancers.

Figure. 1. Correlation between the motion and music beats. The solid curve represents the kinetic velocity of each dancer over time and the vertical dashed line depicts the music beats of the sequence. The motion beats can be detected as the local extrema from the kinetic velocity curve.
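The caption above defines motion beats as local extrema of the kinetic velocity curve. A minimal sketch of that definition follows, assuming each frame is a list of 3D joint positions and that kinetic velocity is the mean joint speed between consecutive frames (the post does not give the exact feature). Music beats would come from Librosa (for example `librosa.beat.beat_track`, which returns the tempo and beat frame indices) after conversion to the motion frame rate; here they are plain integers. `beat_hit_rate` is a hypothetical helper, not the MMC metric reported in Table 1.

```python
import math


def kinetic_velocity(frames):
    """Mean joint speed between consecutive frames (entry i covers frames i -> i+1).

    frames: list of poses; each pose is a list of (x, y, z) joint positions.
    """
    speeds = []
    for prev, cur in zip(frames, frames[1:]):
        dists = [math.dist(p, c) for p, c in zip(prev, cur)]
        speeds.append(sum(dists) / len(dists))
    return speeds


def local_extrema(curve):
    """Indices where the curve is a strict local minimum or maximum (candidate motion beats)."""
    return [i for i in range(1, len(curve) - 1)
            if (curve[i] < curve[i - 1] and curve[i] < curve[i + 1])
            or (curve[i] > curve[i - 1] and curve[i] > curve[i + 1])]


def beat_hit_rate(motion_beats, music_beats, tolerance=2):
    """Hypothetical check: fraction of music beats with a motion beat within `tolerance` frames."""
    if not music_beats:
        return 0.0
    hits = sum(any(abs(m - b) <= tolerance for m in motion_beats) for b in music_beats)
    return hits / len(music_beats)


# One joint moving along x, for illustration only.
frames = [[(x, 0.0, 0.0)] for x in [0, 1, 3, 4, 4.5, 6, 8]]
beats = local_extrema(kinetic_velocity(frames))
print(beats)                         # → [1, 3]
print(beat_hit_rate(beats, [1, 8]))  # → 0.5
```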

Despite the greater variations in high diversity setting, we can observe that the generated motions are matched with the music as the music beats are mostly located near the extrema of the motion curves in both settings. The experiment indicates that our model can faithfully capture different aspects of group choreography with different settings, including diversity and synchrony of the motions. This demonstrates the potential of our method towards various dance applications such as dance training or composing. Furthermore, our method can also produce distinctively different animation sequences under the same setting while adhering to the input music.
It is also important to note that all three setups of GCD significantly outperform other baseline models. This verifies the effectiveness of our proposed approach and shows that it can create high-fidelity group dance animations in any setting. For a more detailed visualization of the results, please refer to Figure 2.

Figure. 2. Consistency and diversity trade-off. High Consistency: parameter $\gamma = 1$; High Diversity: parameter $\gamma = -1$; Neutral: parameter $\gamma = 0$.

2. Number of Dancers Analysis

Table 2 provides insights into results obtained when generating arbitrarily different numbers of dancers using our proposed GCD in the neutral setting. In general, FID, GMR, and GMC metrics do not exhibit a clearly strong correlation with the number of generated dancers but display diverse and varied results. The MMC metric consistently shows its stability across all setups.

Table. 2. Performance of group dance generation methods as the number of generated dancers increases, compared with GDanceR. For GCD, the neutral mode with $\gamma = 0$ is used.

As the number of generated dancers increases, the generation diversity (GenDiv) decreases while the trajectory intersection frequency (TIF) increases. However, it is worth noting that the differences observed in these metrics are relatively minor compared to those produced by GDanceR. This implies that our method can effectively control consistency and diversity, significantly reducing the chances of collisions between dancers and maintaining the overall quality of generated group dance motions.
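The post does not define TIF precisely. As an illustrative assumption, the sketch below counts the fraction of frames in which at least one pair of dancers' floor positions come closer than a threshold. The threshold `min_dist`, the function name, and the toy trajectories are hypothetical, so the numbers it produces are not comparable to Table 2.

```python
import itertools
import math


def collision_frequency(root_paths, min_dist=0.5):
    """Fraction of frames where any two dancers are closer than min_dist.

    root_paths: one list of (x, z) floor positions per dancer, all the same length.
    """
    n_frames = len(root_paths[0])
    collided = 0
    for t in range(n_frames):
        if any(math.dist(a[t], b[t]) < min_dist
               for a, b in itertools.combinations(root_paths, 2)):
            collided += 1
    return collided / n_frames


# Toy trajectories: dancers A and B cross paths at frame 2; C stays far away.
A = [(0, 0), (1, 0), (2, 0), (3, 0)]
B = [(3, 0), (2, 0), (1.8, 0), (0, 0)]
C = [(10, 10)] * 4
print(collision_frequency([A, B, C]))  # → 0.25
```

Checking every pair makes the cost grow quadratically with the number of dancers, which is one reason such a metric tends to rise as more dancers share the same floor space.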

For a detailed visualization of results, please refer to Figure 3. These results underscore the robustness and flexibility of our method in generating group dance motions across varying numbers of dancers while ensuring consistency, diversity, and avoiding collisions between performers.

Figure. 3. Group dance generation results of GCD in terms of different numbers of dancers.

3. Long-term Analysis

To evaluate the efficacy of the guidance signal in GCD for creating long-term group dance sequences, we compared GCD with the baseline model GDanceR. The experiment used musical pieces of different durations: 15, 30, and 60 seconds. We show results with the guidance parameter $\gamma = 0.5$ to enforce consistency with the music over a long duration. For more details, please refer to Figure 5 and our supplementary video.

Figure. 5. Long-term results of the 60-second clip. For clearer visualization, please visit our demo video.

While both methods produce satisfactory results in the first few seconds of the animations (about 5-6 seconds), GDanceR starts to exhibit floating and unrealistic movements or freezes into a mean pose later in the sequence. Figure 4 compares the motion changes of our GCD method and GDanceR. The motion change magnitude is calculated as the average difference of the kinetic features between consecutive frames. The motion change magnitude of GDanceR gradually decreases and approaches zero in the latter half of the 60-second music piece, whereas our method preserves high magnitudes and variation over time. This is because GDanceR generates almost frozen choreographies during this period. In contrast, the group dance motions produced by GCD remain natural, with diverse movements throughout the entire duration of all music samples.
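The motion change magnitude described above is the average difference of kinetic features between consecutive frames. The sketch assumes a per-frame kinetic feature vector and a mean absolute difference across its entries; which features go into that vector is not specified here. `frozen_ratio` is a hypothetical helper that quantifies the "frozen" behavior as the share of near-zero changes.

```python
def motion_change_magnitudes(kinetic_features):
    """Mean absolute change of the kinetic feature vector between consecutive frames."""
    return [
        sum(abs(c - p) for p, c in zip(prev, cur)) / len(cur)
        for prev, cur in zip(kinetic_features, kinetic_features[1:])
    ]


def frozen_ratio(magnitudes, eps=1e-3):
    """Hypothetical helper: share of frame transitions whose change is (almost) zero."""
    return sum(m < eps for m in magnitudes) / len(magnitudes)


features = [[0.0, 0.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]  # toy per-frame features
mags = motion_change_magnitudes(features)
print(mags)  # → [1.0, 0.0, 0.0]
```

A sequence that freezes into a mean pose shows up as a run of near-zero magnitudes, which is the pattern the comparison in Figure 4 describes for GDanceR.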

Figure. 4. Motion change comparison between our proposed GCD and GDanceR. The experiment is conducted on generated group dance results of 60-second music pieces.

These findings confirm that our approach can effectively address the problem of motion generation in long-horizon group dance scenarios. It maintains the motion quality and dynamics of the dance motions, ensuring that the created animations remain visually appealing throughout extended periods. This highlights the advantage of the contrastive strategy to enhance the consistency of the movements of dancers with their group and the music, resulting in significant improvements for long-term dance sequence generation compared to the baseline GDanceR.


In the next post, we will consider Ablation Studies and User Studies.

Controllable Group Choreography using Contrastive Diffusion (Part 4)

In the previous parts, we described the theory behind GCD, including the Music-Motion Transformer. In this part, let us investigate the experimental setup used to verify its effectiveness, as well as the experimental results on generation quality.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Experiments

1.1 Implementation Details

The hidden layer of all the MLPs consists of 512 units followed by GELU activation. The hidden dimension of all attention layers is set to $d = 512$, and the attention adopts a multi-head scheme with 8 attention heads. We also use a feature-wise linear modulation (FiLM) after each attention layer to strengthen the influence of the conditioning context. At the end of each attention block, we append a 2-layer feed-forward network with a feed-forward size of 1024 to enhance the expressivity of the learned features. We extract the features from the raw audio signal by leveraging the representations from the frozen Jukebox, a pre-trained generative model for music, to enhance the model's generalization to several kinds of in-the-wild music. In total, the Group Diffusion Denoising Network consists of $L = 5$ stacked Music-Motion Transformer and Group Global Attention blocks, along with 2 transformer encoder layers to encode the music features. We implement the architecture of the Contrastive Encoder similarly to the Denoising Network but without cross attention, since it does not take the music as input. The output sequence of the Contrastive Encoder is then averaged and fed into an output layer with one unit. We also make the Contrastive Encoder aware of the current step in the diffusion chain by appending the diffusion timestep embedding to the motion sequence, so that it can provide correct guidance signals in the sampling process. Overall, our model has approximately 62M trainable parameters.
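For reference, the architecture hyperparameters above can be collected into a single configuration object. The dataclass below is a hypothetical sketch: the field names are ours, not taken from the GCD repository. The per-head dimension assumes the standard multi-head split of $d$ evenly across heads.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GCDConfig:
    """Hyperparameters stated in Section 1.1 (field names are hypothetical)."""
    mlp_hidden: int = 512          # hidden units in every MLP (followed by GELU)
    d_model: int = 512             # hidden dimension d of all attention layers
    num_heads: int = 8             # attention heads
    ff_layers: int = 2             # feed-forward layers appended to each attention block
    ff_size: int = 1024            # feed-forward width
    num_blocks: int = 5            # L stacked Music-Motion Transformer + Group Global Attention blocks
    music_encoder_layers: int = 2  # transformer encoder layers for the Jukebox features

    @property
    def head_dim(self) -> int:
        # Assumes each head works on an equal slice of the model dimension.
        assert self.d_model % self.num_heads == 0, "d_model must be divisible by num_heads"
        return self.d_model // self.num_heads


cfg = GCDConfig()
print(cfg.head_dim)  # 64
```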

1.2 Training

To train the denoising diffusion network, we use the "simple" objective as:

$$\mathcal{L}_{\rm simple} = \mathbb{E}_{x_0\sim q(x_0|c),\, m\sim [1,M]} \left[\Vert x_0 - \mathcal{G}_\theta(x_m,m,c) \Vert^2_2\right]$$
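The "simple" objective asks the network $\mathcal{G}_\theta$ to predict the clean sample $x_0$ directly rather than the noise. The sketch below draws one Monte Carlo sample of this loss on a flat list of motion values. It assumes the usual closed-form forward process $q(x_m|x_0) = \mathcal{N}(\sqrt{\bar\alpha_m}x_0, (1-\bar\alpha_m)I)$. The denoiser and the $\bar\alpha$ values are hypothetical placeholders.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def q_sample(x0: Sequence[float], alpha_bar_m: float, rng: random.Random) -> Vector:
    """Closed-form forward process q(x_m | x_0) = N(sqrt(a) x_0, (1 - a) I)."""
    return [math.sqrt(alpha_bar_m) * v + math.sqrt(1.0 - alpha_bar_m) * rng.gauss(0.0, 1.0)
            for v in x0]


def simple_loss(x0: Sequence[float], cond: object,
                denoiser: Callable[[Vector, int, object], Vector],
                alpha_bars: Sequence[float], rng: random.Random) -> float:
    """One sample of L_simple = ||x_0 - G(x_m, m, c)||_2^2."""
    m = rng.randint(1, len(alpha_bars))          # m ~ Uniform{1, ..., M}
    x_m = q_sample(x0, alpha_bars[m - 1], rng)
    x0_hat = denoiser(x_m, m, cond)              # the network predicts x_0, not the noise
    return sum((a - b) ** 2 for a, b in zip(x0, x0_hat))


# Toy check with a hypothetical "oracle" denoiser that already knows x_0.
target = [0.1, -0.2, 0.3]


def oracle(x_m: Vector, m: int, c: object) -> Vector:
    return list(target)


print(simple_loss(target, None, oracle, [0.9, 0.5, 0.1], random.Random(0)))  # 0.0
```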

To improve the physical plausibility and prevent artifacts of the generated motion, we also utilize auxiliary geometric losses.

$$\mathcal{L}_{\rm geo} = \lambda_{\rm pos}\mathcal{L}_{\rm pos} + \lambda_{\rm vel}\mathcal{L}_{\rm vel} + \lambda_{\rm foot}\mathcal{L}_{\rm foot}$$

In particular, the geometric losses consist of (i) a joint position loss $\mathcal{L}_{\rm pos}$ to better constrain the global joint hierarchy via forward kinematics; (ii) a velocity loss $\mathcal{L}_{\rm vel}$ to increase the smoothness and naturalness of the motion by penalizing the difference between the velocities of the ground-truth and predicted motions; and (iii) a foot contact loss $\mathcal{L}_{\rm foot}$ to mitigate foot skating artifacts and improve the realism of the generated motions by ensuring that the feet stay stationary when ground contact occurs.
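A minimal sketch of the velocity and foot contact terms, on plain per-frame position lists. Several choices here are assumptions for illustration: finite-difference velocities, squared errors averaged over entries, and binary contact labels supplied as input. The official code may, for example, predict contact labels instead. $\mathcal{L}_{\rm pos}$ is omitted because it needs a forward-kinematics skeleton. The default weights are the values reported in the Training section.

```python
from typing import List, Sequence

Frames = Sequence[Sequence[float]]


def velocities(frames: Frames) -> List[List[float]]:
    """Finite-difference velocity between consecutive frames."""
    return [[c - p for p, c in zip(prev, curr)] for prev, curr in zip(frames[:-1], frames[1:])]


def velocity_loss(pred: Frames, gt: Frames) -> float:
    """Mean squared difference between predicted and ground-truth velocities."""
    errs = [(a - b) ** 2
            for fp, fg in zip(velocities(pred), velocities(gt))
            for a, b in zip(fp, fg)]
    return sum(errs) / len(errs)


def foot_contact_loss(foot_pred: Frames, contact: Sequence[int]) -> float:
    """Penalise predicted foot motion on steps labelled as ground contact.

    contact[t] = 1 means the foot should stay put between frame t and t + 1,
    so len(contact) == len(foot_pred) - 1.
    """
    vel = velocities(foot_pred)
    errs = [contact[t] * sum(v ** 2 for v in vel[t]) for t in range(len(vel))]
    return sum(errs) / len(errs)


def geometric_loss(l_pos: float, l_vel: float, l_foot: float,
                   w_pos: float = 1.0, w_vel: float = 1.0, w_foot: float = 0.005) -> float:
    """L_geo = w_pos * L_pos + w_vel * L_vel + w_foot * L_foot."""
    return w_pos * l_pos + w_vel * l_vel + w_foot * l_foot


gt = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]
pred = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]   # speeds up in the last step
print(velocity_loss(pred, gt))  # 0.25
```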

Our total training objective is the combination of the "simple" diffusion objective, the auxiliary geometric losses, and the contrastive loss:

$$\mathcal{L} = \mathcal{L}_{\rm simple} + \mathcal{L}_{\rm geo} + \lambda_{\rm nce}\mathcal{L}_{\rm nce}$$

We train our model on 4 NVIDIA V100 GPUs using the Adam optimizer with a learning rate of $10^{-4}$ and a batch size of 64 per GPU, which took about 7 days for 500k iterations. The models are trained with $M = 1000$ diffusion noising steps and a cosine noise schedule. During training, group dance motions are randomly sampled with sequence length $T = 150$ at 30 Hz, which corresponds to 5-second pieces of music. For the geometric losses, the loss weights are empirically set to $\lambda_{\rm pos} = 1.0$, $\lambda_{\rm vel} = 1.0$, and $\lambda_{\rm foot} = 0.005$. For the contrastive loss $\mathcal{L}_{\rm nce}$, the weight is $\lambda_{\rm nce} = 0.001$, the probability of replacing dancers for negative sequences is 0.5, and the number of negative samples is empirically set to 10.
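The post states that $M = 1000$ steps are used with a cosine noise schedule, but it does not give the exact form. The sketch below assumes the standard cosine schedule of Nichol and Dhariwal (2021): $\bar\alpha_m = f(m)/f(0)$ with $f(m) = \cos^2\!\big(\tfrac{m/M + s}{1 + s}\cdot\tfrac{\pi}{2}\big)$, offset $s = 0.008$, and per-step $\beta_m$ clipped at 0.999.

```python
import math
from typing import List


def cosine_f(m: float, num_steps: int, s: float = 0.008) -> float:
    """Unnormalised signal level f(m) of the cosine schedule."""
    return math.cos(((m / num_steps) + s) / (1.0 + s) * math.pi / 2.0) ** 2


def cosine_betas(num_steps: int = 1000, max_beta: float = 0.999) -> List[float]:
    """Per-step noise variances beta_m for m = 1..M."""
    f0 = cosine_f(0, num_steps)
    betas = []
    for m in range(1, num_steps + 1):
        ab_prev = cosine_f(m - 1, num_steps) / f0
        ab_curr = cosine_f(m, num_steps) / f0
        betas.append(min(1.0 - ab_curr / ab_prev, max_beta))
    return betas


def alpha_bars(betas: List[float]) -> List[float]:
    """Cumulative products alpha_bar_m = prod_{k <= m} (1 - beta_k)."""
    out, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        out.append(running)
    return out


betas = cosine_betas(1000)
abar = alpha_bars(betas)
print(len(betas))                        # 1000
print(abar[0] > 0.99, abar[-1] < 1e-3)   # True True
```

The schedule keeps almost all of the signal at the first step and destroys nearly all of it by step $M$.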

1.3 Testing

At test time, we use the DDIM sampling technique with 50 steps to accelerate the reverse diffusion process. Accordingly, our model can achieve real-time generation at 30 Hz on a single RTX 2080Ti GPU (excluding the music feature extraction step), thanks to the parallelization of the Transformer architecture.
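To make the 50-step sampling concrete, here is a sketch under two assumptions that the post does not confirm. First, the 50 steps are evenly spaced over the 1000 training steps. Second, the deterministic DDIM update ($\eta = 0$) is used, rewritten for a network that predicts $x_0$. The implied noise is recovered from $x_m$ and the $x_0$ estimate, and the two are then recombined at the next, lower noise level.

```python
import math
from typing import List, Sequence


def ddim_timesteps(num_train_steps: int = 1000, num_sample_steps: int = 50) -> List[int]:
    """Evenly spaced diffusion steps, from noisiest to cleanest."""
    stride = num_train_steps // num_sample_steps
    return list(range(num_train_steps, 0, -stride))


def ddim_step(x_m: Sequence[float], x0_hat: Sequence[float],
              abar_m: float, abar_next: float) -> List[float]:
    """Deterministic DDIM update (eta = 0) for an x_0-predicting model."""
    out = []
    for xm, x0 in zip(x_m, x0_hat):
        eps = (xm - math.sqrt(abar_m) * x0) / math.sqrt(1.0 - abar_m)   # implied noise
        out.append(math.sqrt(abar_next) * x0 + math.sqrt(1.0 - abar_next) * eps)
    return out


steps = ddim_timesteps()
print(len(steps), steps[:3], steps[-1])  # 50 [1000, 980, 960] 20
```

The final update goes from step 20 to the clean sample ($\bar\alpha = 1$), where `ddim_step` simply returns the model's $x_0$ estimate.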

To enable long-term generation, we divide the input music sequence into multiple overlapping chunks, each with a maximum window size of 5 seconds and overlapping by half with the adjacent chunk. The group dance motions are then generated for each chunk along with the corresponding audio. Subsequently, we merge the outputs by blending the overlapped region between two consecutive chunks using spherical linear interpolation, with the interpolation weight gradually decaying from the current chunk to the next chunk. However, for group choreography synthesis, our model generates dance motions for each dancer in random order. Therefore, we need to establish correspondences between dancers across the chunks (i.e., identify which of the $N$ dancers in the next chunk corresponds to a dancer in the current chunk). To accomplish this, we organize all dancers in the current chunk into one set and the dancers in the next chunk into another set, forming a bipartite graph between the two chunks. We then use the Hungarian algorithm to find the optimal matching, with the Euclidean distance between the two pose sequences serving as the matching weight. Our blending technique is applied at each step of the diffusion sampling process, starting from pure noise, which allows the model to gradually denoise the chunks so that they are compatible for blending.
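The chunking, matching, and blending steps above can be sketched as follows. Several pieces are assumptions for illustration. Each dancer's overlap region is a list of flat per-frame features. Joint rotations are unit quaternions for slerp. The decaying weight is a linear ramp. An exhaustive permutation search stands in for the Hungarian algorithm: it reaches the same optimum, but its cost grows factorially, so `scipy.optimize.linear_sum_assignment` would be the practical choice for larger groups. Repeating the blend at every diffusion step is not shown.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Track = Sequence[Sequence[float]]  # one dancer's per-frame pose features in the overlap


def chunk_windows(num_frames: int, window: int = 150) -> List[Tuple[int, int]]:
    """[start, end) windows of at most `window` frames (5 s at 30 Hz), overlapping by half."""
    stride = window // 2
    windows, start = [], 0
    while True:
        end = min(start + window, num_frames)
        windows.append((start, end))
        if end >= num_frames:
            return windows
        start += stride


def sequence_distance(a: Track, b: Track) -> float:
    """Euclidean distance between two pose sequences (all frames stacked)."""
    return math.sqrt(sum((x - y) ** 2 for pa, pb in zip(a, b) for x, y in zip(pa, pb)))


def match_dancers(current: Sequence[Track], nxt: Sequence[Track]) -> Tuple[int, ...]:
    """perm[i] = index of the dancer in `nxt` matched to dancer i in `current`.

    Exhaustive search minimises the same total cost as the Hungarian algorithm.
    """
    n = len(current)
    return min(itertools.permutations(range(n)),
               key=lambda p: sum(sequence_distance(current[i], nxt[p[i]]) for i in range(n)))


def blend_weights(overlap: int) -> List[float]:
    """Weight on the next chunk per overlapping frame (the current chunk gets 1 - w)."""
    if overlap <= 1:
        return [1.0] * overlap
    return [k / (overlap - 1) for k in range(overlap)]


def slerp(q0: Sequence[float], q1: Sequence[float], t: float) -> List[float]:
    """Spherical linear interpolation between unit quaternions (t=0 -> q0, t=1 -> q1)."""
    dot = sum(a * b for a, b in zip(q0, q1))
    if dot < 0.0:                      # flip to follow the shorter arc
        q1, dot = [-b for b in q1], -dot
    if dot > 0.9995:                   # nearly identical: normalised linear interpolation
        out = [a + t * (b - a) for a, b in zip(q0, q1)]
        norm = math.sqrt(sum(v * v for v in out))
        return [v / norm for v in out]
    theta = math.acos(dot)
    w0 = math.sin((1.0 - t) * theta) / math.sin(theta)
    w1 = math.sin(t * theta) / math.sin(theta)
    return [w0 * a + w1 * b for a, b in zip(q0, q1)]


print(chunk_windows(300))            # [(0, 150), (75, 225), (150, 300)]
cur = [[[0.0, 0.0]], [[5.0, 5.0]]]   # dancer 0 near the origin, dancer 1 far away
nxt = [[[5.1, 4.9]], [[0.1, 0.0]]]   # the same two dancers, listed in swapped order
print(match_dancers(cur, nxt))       # (1, 0)
```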

2. Experimental Settings

2.1 Dataset

We use the AIOZ-GDance dataset in our experiments. AIOZ-GDance is a large-scale group dance dataset that includes paired music and 3D group motions captured from in-the-wild videos using a semi-automatic method, covering 7 dance styles and 16 music genres.

2.2 Evaluation Protocol

We use the following metrics to evaluate the quality of individual dance motions: Frechet Inception Distance (FID), Motion-Music Consistency (MMC), Generation Diversity (GenDiv), and Physical Foot Contact score (PFC). Concretely, the FID score measures the realism of individual dance movements against the ground-truth dance. The MMC evaluates the matching similarity between the motion and the music beats, i.e., how well generated dances follow the beat of the music. The generation diversity (GenDiv) is evaluated as the average pairwise distance of the kinetic features of the motions. The PFC evaluates the physical plausibility of the foot movements by calculating the agreement between the acceleration of the character's center of mass and the foot's velocity.
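As a small illustration of the GenDiv definition, the sketch below averages the distances over all unordered pairs of motion feature vectors. Euclidean distance is an assumption. The kinetic feature extraction, and any normalisation applied before it, is not shown.

```python
import itertools
import math
from typing import Sequence


def gen_div(features: Sequence[Sequence[float]]) -> float:
    """Average Euclidean distance over all unordered pairs of motion feature vectors."""
    pairs = list(itertools.combinations(features, 2))
    if not pairs:
        return 0.0
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)


print(gen_div([[0.0, 0.0], [3.0, 4.0]]))  # 5.0
```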

To evaluate the group dance quality, we use three metrics: Group Motion Realism (GMR), Group Motion Correlation (GMC), and Trajectory Intersection Frequency (TIF).
In general, the GMR measures the realism of generated group motions against the ground-truth group motions by calculating the Frechet Inception Distance on the extracted group motion features. The GMC evaluates the synchrony between dancers within the generated group by calculating their cross-correlation. The TIF measures how often the generated dancers collide with each other in their dance movements.

2.3 Baselines

We compare our GCD method with several recent approaches to music-driven dance generation: FACT, Transflower, and EDGE. All of them are adapted for benchmarking in the context of group dance generation, since the original methods were designed for single-dancer generation. We also evaluate against GDanceR, a recent model specifically designed for generating group choreography.

3. Quality Comparison

Table 1 shows a comparison among the baselines FACT, Transflower, EDGE, GDanceR, and our proposed GCD. The results clearly demonstrate that our default model setting with "neutral" mode outperforms the baselines significantly across all evaluations. We also observe that EDGE, a recent diffusion dance generation model, can yield very competitive performance on single-dance metrics (FID, MMC, GenDiv, and PFC). This suggests the advantages of diffusion approaches in motion generation tasks. However, it is still inferior to our model under several group dance metrics, showing the limitations of single dance methods in the context of group dance creation. Experimental results highlight the effectiveness of our approach in generating high-quality group dance motions.

Table. 1. Performance comparison. High Consistency: parameter $\gamma = 1$; High Diversity: parameter $\gamma = -1$; Neutral: parameter $\gamma = 0$.

To complement the quantitative analysis, we present qualitative examples from FACT, GDanceR, and our GCD method in Figure 1. Notably, FACT struggles with the intersection problem, which is reasonable given that it was not originally designed for group dance generation. As a result, the generated motions from FACT lack coordination and synchronization in most cases. While GDanceR improves motion quality compared to FACT, its generated motions often appear floating, unnatural, and sometimes unsynchronized. These drawbacks indicate that GDanceR still requires further refinement to produce consistent and cohesive movements among the dancers.

Figure. 1 Comparison between different dance generation methods when generating dancing in groups.

In contrast, our method excels in both controlling consistency and promoting diversity among the generated group dance motions. The outputs from our method demonstrate well-coordinated and realistic movements, implying that it can resolve the challenges of maintaining group coherence while delivering visually appealing results more effectively.

Overall, the conducted quantitative analysis and visual comparisons reaffirm the superior performance of our proposed GCD to generate high-quality, synchronized, and visually pleasing group dance motions.


In the next post, we will mention our Consistency and Diversity Analysis.

Controllable Group Choreography using Contrastive Diffusion (Part 3)

In the previous part, we discussed the Music-Motion Transformer and Group Global Attention. In this part, let us investigate Group Modulation and the Contrastive Diffusion loss.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Group Modulation

To better apply the group information constraints to the learned hidden features of the dancers, we adopt a Group Modulation layer. It learns to adaptively influence the output of the transformer attention block by applying an affine transformation to the intermediate features based on the group embedding $w$. More specifically, we utilize two separate linear layers to learn the affine transformation parameters $\{S(w); b(w)\} \in \mathbb{R}^d$ from the group embedding $w$. The predicted affine parameters are then used to modulate the activation sequence $h = \{h^1_1\dots h^1_T;\dots;h^N_1\dots h^N_T\}$ as follows:

$$\tilde{h} = S(w) * \frac{h-\mu(h)}{\sigma(h)}+ b(w)$$

where each channel of the whole activation sequence is first normalized separately by calculating its mean $\mu$ and standard deviation $\sigma$, and then scaled and biased using the predicted affine parameters $S(w)$ and $b(w)$. Intuitively, this operation shifts the activated hidden motion features of each individual motion towards a unified group representation, which further encourages the association between them. Finally, the output features are projected back to the original motion dimensions via a linear layer to obtain the predicted outputs $\hat{x}_0$.
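The modulation can be sketched in plain Python as follows. The activation sequence holds one row per (dancer, frame) token. Each channel is normalised with statistics taken over all $N \cdot T$ tokens, then scaled by $S(w)$ and shifted by $b(w)$. The two linear layers use hypothetical toy weights. Population statistics and the small `eps` added for numerical stability are assumptions.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def linear(w: Sequence[float], weight: Sequence[Sequence[float]],
           bias: Sequence[float]) -> List[float]:
    """y = weight @ w + bias (weight given as d rows of len(w) entries)."""
    return [sum(a * b for a, b in zip(row, w)) + c for row, c in zip(weight, bias)]


def group_modulation(h: Sequence[Sequence[float]], scale: Sequence[float],
                     shift: Sequence[float], eps: float = 1e-5) -> Matrix:
    """Normalise each channel over the whole group sequence, then apply S(w), b(w)."""
    n, d = len(h), len(h[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [row[j] for row in h]
        mu = sum(col) / n
        sigma = math.sqrt(sum((v - mu) ** 2 for v in col) / n + eps)
        for i in range(n):
            out[i][j] = scale[j] * (h[i][j] - mu) / sigma + shift[j]
    return out


w = [1.0, -1.0]                                       # toy group embedding
S = linear(w, [[1.0, 0.0], [0.0, -2.0]], [0.0, 0.0])  # S(w) = [1.0, 2.0]
b = linear(w, [[0.0, 0.0], [0.0, 0.0]], [0.5, -0.5])  # b(w) = [0.5, -0.5]
h = [[1.0, 10.0], [3.0, 10.0], [5.0, 40.0], [7.0, 40.0]]  # 2 dancers x 2 frames, d = 2
out = group_modulation(h, S, b)
```

After modulation, each channel has mean $b_j(w)$ and spread $|S_j(w)|$, whatever the dancer it came from. This is the "shift towards a unified group representation" described above.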

2. Contrastive Diffusion

We learn representations that encode the underlying shared information between the group embedding information $w$ and the group sequence $x$. Specifically, we model a density ratio that preserves the mutual information between $x$ and $w$ as:

$$f(x, w) \propto \frac{p(x|w)}{p(x)}$$

where $f(\cdot)$ is a model (i.e., a neural network) that predicts a positive score (how well $x$ is related to $w$) for a pair $(x, w)$.

To enhance the association between the generated group dance (data) and the group embedding (context), we aim to maximize their mutual information with a Contrastive Encoder $f(\hat{x}, w)$ via the contrastive learning objective in the equation below. The encoder takes both the generated group dance sequence $\hat{x}$ and a group embedding $w$ as inputs, and it outputs a score indicating the correspondence between the two.

$$\mathcal{L}_{\rm nce} = - \mathbb{E} \left[ \log\frac{f(\hat{x},w)}{f(\hat{x},w) + \sum_{\hat{x}^j \in X'}f(\hat{x}^j, w)}\right]$$

where $X'$ is a set of randomly constructed negative sequences. In general, this loss is similar to the cross-entropy loss for classifying the positive sample, and optimizing it maximizes the mutual information between the learned context representation and the data. Using the contrastive objective, we expect the Contrastive Encoder to learn to distinguish between two quantities: consistency (the positive sequence) and diversity (the negative sequence). This is the key factor that enables control over diversity and consistency in our framework.
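For a single positive sequence and its negatives, the loss is a softmax cross-entropy in which the positive occupies the "correct" slot. The sketch assumes the encoder's one-unit output is a logit and that $f = \exp(\text{logit})$, a common way to keep scores positive. The post does not say how positivity is enforced. The expectation would be an average over a batch.

```python
import math
from typing import Sequence


def nce_loss(pos_score: float, neg_scores: Sequence[float]) -> float:
    """-log( f(x, w) / (f(x, w) + sum_j f(x^j, w)) ) for positive scores."""
    return -math.log(pos_score / (pos_score + sum(neg_scores)))


def nce_loss_from_logits(pos_logit: float, neg_logits: Sequence[float]) -> float:
    """Same loss with f = exp(logit), computed stably via log-sum-exp."""
    logits = [pos_logit, *neg_logits]
    top = max(logits)
    log_denominator = top + math.log(sum(math.exp(v - top) for v in logits))
    return log_denominator - pos_logit


# With 10 negatives (as in the training setup) scored the same as the positive,
# the loss is log(11). It falls toward 0 as the positive's logit grows.
print(nce_loss_from_logits(0.0, [0.0] * 10))
print(nce_loss_from_logits(5.0, [0.0] * 10))
```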

Here, we describe our strategy for constructing contrastive samples to achieve this target. Recall that we use the reverse distribution $p_\theta(x_{m-1} | x_m)$ of Gaussian diffusion, whose mean is predicted by the model while the variance is fixed by a schedule. To obtain the contrastive samples, given the true pair $(x_0, w)$, we first use the forward diffusion process $q(x_m|x_0)$ to obtain the noised sample $x_m$. Then, our positive sample is $p_\theta(x_{m-1} | x_m, w)$. Subsequently, we construct the negative sample from the positive pair by randomly replacing dancers, with some probability, with dancers from other group dance sequences ($x^j_0 \neq x_0$). We feed the result through the forward process to obtain $x^j_m$; our negative sample is then $p_\theta(x^j_{m-1} | x^j_m, w)$. By constructing contrastive samples this way, the positive pair $(x_0, w)$ represents a group sequence with high consistency, whereas the negative one represents a high-diversity sample. This is because mixing a sample with dancers from different groups is likely to result in substantially different movements between dancers, making it a group dance sample with a high degree of diversity. Note that negative sequences should also match the music, because they are motions generated by the network from inputs that were manipulated to increase diversity. In particular, negative samples are obtained from the outputs of the denoising network, whose inputs are both the current music and the noised mixed group with some replaced dancers. As the network is trained to reconstruct only positive samples, its outputs will likely follow the music. Therefore, negative samples are not just random bad samples; they are valid group dances generated by a network trained to produce group dance conditioned on the music.
This is because our main diffusion training objective is calculated only for ground-truth dances (positive samples) that are consistent with the music. Our proposed strategy also allows us to learn a more powerful group representation, as it directly affects the reverse process, which helps maintain consistency in long-term synthesis.
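The dancer-replacement step that builds the negative set $X'$ can be sketched as below, using the replacement probability 0.5 and the 10 negatives reported in the Training section. Dancers are treated as opaque items (for example, per-dancer motion arrays). The per-dancer coin flip is our reading of "with some probability". The subsequent noising and denoising of each negative, which yields $p_\theta(x^j_{m-1} | x^j_m, w)$, is not shown.

```python
import random
from typing import List, Sequence, TypeVar

D = TypeVar("D")  # one dancer's motion, in any representation


def make_negative(group: Sequence[D], other_groups: Sequence[Sequence[D]],
                  p_replace: float, rng: random.Random) -> List[D]:
    """Replace each dancer, with probability p_replace, by a random dancer
    from a different group sequence. A draw may leave the group unchanged."""
    negative = list(group)
    for i in range(len(negative)):
        if rng.random() < p_replace:
            donor_group = rng.choice(other_groups)
            negative[i] = rng.choice(donor_group)
    return negative


def make_negatives(group: Sequence[D], other_groups: Sequence[Sequence[D]],
                   num_negatives: int = 10, p_replace: float = 0.5,
                   seed: int = 0) -> List[List[D]]:
    """Build the negative set X' for one ground-truth group."""
    rng = random.Random(seed)
    return [make_negative(group, other_groups, p_replace, rng) for _ in range(num_negatives)]


negs = make_negatives(["A1", "A2", "A3"], [["B1", "B2", "B3"], ["C1", "C2", "C3"]])
```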

3. Diversity vs. Consistency

Using the Contrastive Encoder $f(x_m, w)$, we extend classifier guidance to control the generation process. Accordingly, we incorporate $f(x_m, w)$ in the contrastive framework to replace the guiding classifier in the original formula, since it provides a score of how consistent the sample is with the group information. In particular, we shift the mean of the reverse diffusion process with the log gradient of the Contrastive Encoder with respect to the generated data as follows:

$$\hat{\mu}_\theta(x_m,m) = \mu_\theta(x_m,m) + \gamma \cdot \Sigma_{\theta}(x_m,m) \nabla_{x_m}\log f(x_m,w)$$

where $\gamma$ is the control parameter that uses the encoder to enforce consistency and connection with the group embedding. Since the Contrastive Encoder is trained to classify between high-consistency and high-diversity samples, its gradients yield meaningful guidance signals to control the trade-off. Intuitively, a positive value of $\gamma$ encourages more consistency between dancers, while a negative value (which corresponds to shifting the distribution with a negative gradient step) boosts the diversity between individual dancers.
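A toy version of the guided mean shows how the sign of $\gamma$ flips the direction of the push. The sketch assumes a diagonal $\Sigma_\theta$. It replaces the Contrastive Encoder with a hypothetical score, $\log f(x, w) = -\lVert x - w\rVert^2 / 2$, whose gradient $w - x$ has a closed form. With the real encoder, the gradient would come from automatic differentiation.

```python
from typing import List, Sequence


def guided_mean(mu: Sequence[float], variance: Sequence[float],
                grad_log_f: Sequence[float], gamma: float) -> List[float]:
    """mu_hat = mu + gamma * Sigma * grad_x log f(x_m, w), with a diagonal Sigma."""
    return [m + gamma * v * g for m, v, g in zip(mu, variance, grad_log_f)]


def toy_grad_log_f(x: Sequence[float], w: Sequence[float]) -> List[float]:
    """Gradient of the hypothetical score log f(x, w) = -||x - w||^2 / 2."""
    return [wi - xi for xi, wi in zip(x, w)]


x_m = [0.0, 0.0]
w = [1.0, -1.0]
mu = [0.0, 0.0]                    # unguided mean from the denoiser
var = [0.5, 0.5]                   # fixed schedule variance
g = toy_grad_log_f(x_m, w)
print(guided_mean(mu, var, g, gamma=1.0))   # [0.5, -0.5]: pulled toward higher score
print(guided_mean(mu, var, g, gamma=-1.0))  # [-0.5, 0.5]: pushed away
```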


In the next post, we will mention Experimental Setups and Analysis.

Controllable Group Choreography using Contrastive Diffusion (Part 2)

In the previous part, we discussed the motivation and recent works on generating dances and group dances, and analyzed the pros and cons of each method. We now dive into the detailed algorithm of our proposed GCD. Our methodology can be split into three parts: the Music-Motion Transformer, Group Global Attention, and the Contrastive Diffusion Loss for Controllable Motions. In this part, let us investigate the Music-Motion Transformer first.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Group Diffusion Denoising Network

Our model architecture is illustrated in Figure 1. We utilize a transformer-based architecture to generate the whole sequence in one go. Compared with the recent auto-regressive approach (Le et al., 2023), our method does not suffer from the error accumulation problem and can therefore generate arbitrarily long dance motion sequences without freezing effects. (In the auto-regressive fashion, the current-frame outputs are used as inputs to the next frame, so the prediction error accumulates over time.) The input of our network at each diffusion step $m$ is the noisy group sequence $x_m = \{x^1_{m,1},\dots, x^1_{m,T}; \dots; x^N_{m,1},\dots,x^N_{m,T}\}$; however, we drop the $m$ index for ease of notation from now on.

2. Music-Motion Transformer

Given an input extracted audio sequence $a = \{a_1, a_2, \dots, a_T\}$, we employ a transformer encoder architecture to encode the music into a sequence of hidden audio representations $\{c_1, c_2, \dots, c_T\}$, which serve as the conditioning context for the diffusion denoising network. Specifically, we follow the standard encoder layer, which consists of multi-head self-attention layers and feed-forward layers, to effectively encode the multi-scale rhythmic patterns and long-term dependencies between music frames. The diffusion time-step $m$ is also projected to the transformer dimension through a separate Multi-layer Perceptron (MLP) with 3 hidden layers to get the embedding $\tau_{emb}$. This embedding is then concatenated with the music feature sequence to obtain the final conditioning context $c = \{c_1, c_2, \dots, c_T, \tau_{emb}\}$.

Figure 1. Detailed illustration of our method for group choreography generation.

Although group choreography involves learning the interaction between dancers, we still need to learn the correlation between the dance movements and the accompanying music for each dancer. Therefore, we design the Music-Motion Transformer to focus on learning the direct connection between the motion and the music of each individual dancer, without yet considering the interconnection among dancers. Each frame of the noised input motion $x^i_t$ is projected into the transformer dimension by a linear layer followed by an additive positional encoding. Given the whole group sequence including all dancers $\{x^1_1,\dots, x^1_T; \dots; x^N_1,\dots,x^N_T\}$, we separately encode the motion features of each individual dancer using multi-head self-attention with a masking strategy. We implement the masked self-attention (MSA) mechanism as follows:

$$\text{MSA}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + m_{local} \right) V, \quad Q = xW^{Q}, \quad K = xW^{K}, \quad V = xW^{V}$$

where $W^{Q}, W^{K} \in \mathbb{R}^{d \times d_k}$ and $W^{V} \in \mathbb{R}^{d \times d_v}$ are learnable projection matrices that transform the input into query, key, and value, respectively. $m_{local}$ is the local attention mask illustrated in Figure 2(a). This mask ensures that each individual can only attend to their own motion. Subsequently, to incorporate the music conditioning context $c$ into each individual's motion features, we adopt a transformer decoder architecture with a cross-attention mechanism (CA), where the motion is the query and the music is the key/value.
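The local mask can be built directly from the token ordering. The sketch below assumes tokens are flattened dancer by dancer, as in the group sequence notation: $x^1_1..x^1_T$, then $x^2_1..x^2_T$, and so on. For contrast, it also includes the full mask $m_{global}$ used by Group Global Attention. Adding $-\infty$ before the softmax gives those keys exactly zero weight.

```python
import math
from typing import List

NEG_INF = float("-inf")


def local_mask(num_dancers: int, num_frames: int) -> List[List[float]]:
    """m_local: token (i, t) may attend to token (j, s) only when i == j."""
    size = num_dancers * num_frames
    return [[0.0 if q // num_frames == k // num_frames else NEG_INF for k in range(size)]
            for q in range(size)]


def global_mask(num_dancers: int, num_frames: int) -> List[List[float]]:
    """m_global: every token may attend to every token of every dancer."""
    size = num_dancers * num_frames
    return [[0.0] * size for _ in range(size)]


def masked_softmax(scores: List[float], mask_row: List[float]) -> List[float]:
    """softmax(scores + mask) for one query row; -inf entries get zero weight."""
    shifted = [s + m for s, m in zip(scores, mask_row)]
    top = max(shifted)
    exps = [math.exp(v - top) for v in shifted]
    total = sum(exps)
    return [e / total for e in exps]


m = local_mask(num_dancers=2, num_frames=2)
print(m[0])                                      # [0.0, 0.0, -inf, -inf]
print(masked_softmax([1.0, 1.0, 5.0, 5.0], m[0]))  # [0.5, 0.5, 0.0, 0.0]
```

Even though dancer 2's keys score higher in this example, the local mask keeps dancer 1's first frame attending only to dancer 1.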

$$\text{CA}(\tilde{Q}, \tilde{K}, \tilde{V}) = \text{softmax}\left(\frac{\tilde{Q}\tilde{K}^\top}{\sqrt{d_k}} \right) \tilde{V}, \quad \tilde{Q} = \tilde{x} \tilde{W}^{Q}, \quad \tilde{K} = c \tilde{W}^{K}, \quad \tilde{V} = c \tilde{W}^{V}$$

where $\tilde{x}$ is the output activation of the MSA block, and $\tilde{W}^{Q}$, $\tilde{W}^{K}$, $\tilde{W}^{V}$ are learnable projection matrices that play the same role as in the MSA mechanism.
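Both attention variants share the same core computation. The single-head, unbatched sketch below (plain Python, with hypothetical projection matrices) makes this explicit. `attention` with a mask corresponds to MSA, with all of Q, K, and V projected from the motion. `cross_attention` projects queries from the MSA output $\tilde{x}$ and keys/values from the music context $c$, and uses no mask.

```python
import math
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def matmul(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def attention(q: Matrix, k: Matrix, v: Matrix, mask: Optional[Matrix] = None) -> Matrix:
    """Single-head softmax(Q K^T / sqrt(d_k) + mask) V."""
    d_k = len(q[0])
    out = []
    for i, q_row in enumerate(q):
        logits = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k)
                  + (mask[i][j] if mask is not None else 0.0)
                  for j, k_row in enumerate(k)]
        top = max(logits)
        weights = [math.exp(z - top) for z in logits]
        total = sum(weights)
        out.append([sum(w / total * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out


def cross_attention(x_tilde: Matrix, c: Matrix,
                    w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """CA: queries from the MSA output x~, keys and values from the music context c."""
    return attention(matmul(x_tilde, w_q), matmul(c, w_k), matmul(c, w_v))


I2 = [[1.0, 0.0], [0.0, 1.0]]
Z2 = [[0.0, 0.0], [0.0, 0.0]]
# With all-zero keys every music frame gets equal weight, so the output is their mean.
print(cross_attention([[3.0, -1.0]], I2, I2, Z2, I2))  # [[0.5, 0.5]]
```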

3. Group Global Attention

To ensure coherent and collision-free movements of all dancers within the group, so that their dances correlate with each other under the music condition instead of being performed asynchronously, we first perform global attention via a masked attention mechanism with a full masking strategy $m_{global}$. The attention mask is illustrated in Figure 2(b). It allows a dancer to fully attend to all other dancers under the global receptive field. Then, we propose Group Modulation to enforce the group constraints within the group embedding information.

Figure 2. The local attention mask $m_{local}$ (a) and global attention mask $m_{global}$ (b). Blue cells indicate where frames can attend to each other. Blue represents a zero value of the mask, while gray represents minus infinity. $x^i_{[1:T]}$ indicates the motion sequence of the $i$-th dancer.

Similar to style-based image generators, where a synthesized image can be manipulated via a latent style vector, we aim to learn group embedding information from the input music in order to control the group dance generation process. We first apply temporal average pooling to the encoded music feature sequence to obtain a compact representation of the input music, $\bar{c} = \frac{1}{T}\sum_{t=1}^T c_t$. To increase the variation and diversity of the group information (i.e., to avoid limiting the group embedding to only one style of the input music), we inject random noise drawn from a standard Gaussian distribution $z \sim \mathcal{N}(0,I)$ into $\bar{c}$. We use an 8-layer MLP to learn a mapping from the audio representation to the group embedding. We also add a learnable embedding token $e_{n}$ from a variable-size lookup table $E \in \mathbb{R}^{N\times D}$, supporting up to $N$ dancers. This token represents the variation in the number of dancers, since each sequence may contain a different number of dancers. In summary, the process can be written as follows:

$$w = \text{MLP}\left(z + \frac{1}{T}\sum_{t=1}^T c_t\right) + e_{n}, \quad z \sim \mathcal{N}(0,I)$$
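The equation above can be traced step by step in a small sketch. The `mlp` argument is a hypothetical stand-in for the 8-layer MLP. Looking up row $n-1$ of the table for a group of $n$ dancers is our assumption about the indexing. In this toy setup, the MLP output dimension equals the table width $D$.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]


def group_embedding(music_feats: Sequence[Sequence[float]], num_dancers: int,
                    mlp: Callable[[Vector], Vector],
                    dancer_table: Sequence[Sequence[float]],
                    rng: random.Random) -> Vector:
    """w = MLP(z + (1/T) sum_t c_t) + e_n, with z ~ N(0, I)."""
    T, d = len(music_feats), len(music_feats[0])
    c_bar = [sum(c[k] for c in music_feats) / T for k in range(d)]  # temporal average pooling
    z = [rng.gauss(0.0, 1.0) for _ in range(d)]                      # standard Gaussian noise
    hidden = mlp([zk + ck for zk, ck in zip(z, c_bar)])
    e_n = dancer_table[num_dancers - 1]   # hypothetical: row n-1 holds e_n for n dancers
    return [h + e for h, e in zip(hidden, e_n)]


def toy_mlp(v: Vector) -> Vector:
    """Hypothetical one-layer stand-in for the 8-layer MLP (identity + ReLU)."""
    return [max(0.0, x) for x in v]


music = [[1.0, 2.0], [3.0, 4.0]]                  # T = 2 frames, 2 channels
table = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4]]      # groups of up to 3 dancers
w = group_embedding(music, num_dancers=2, mlp=toy_mlp, dancer_table=table,
                    rng=random.Random(0))
```

Because $z$ is resampled on every call, the same music can map to different group embeddings. This is the source of variation that the noise injection is meant to provide.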


In the next post, we will discuss Group Modulation and Contrastive Diffusion.

Controllable Group Choreography using Contrastive Diffusion (Part 1)

Music-driven group choreography poses a considerable challenge but holds significant potential for a wide range of industrial applications. The ability to generate synchronized and visually appealing group dance motions that are aligned with music opens up opportunities in many fields such as entertainment, advertising, and virtual performances. However, most recent works are not able to generate high-fidelity long-term motions or fail to enable a controllable experience. In this work, we aim to address the demand for high-quality and customizable group dance generation by effectively governing the consistency and diversity of group choreographies. In particular, we utilize a diffusion-based generative approach to enable the synthesis of long-term group dances with a flexible number of dancers, while ensuring coherence with the input music. Ultimately, we introduce a Group Contrastive Diffusion (GCD) strategy that enhances the connection between dancers and their group and makes it possible to control the consistency or diversity level of the synthesized group animation via the classifier-guidance sampling technique. Through extensive experiments and evaluation, we demonstrate the effectiveness of our approach in producing visually captivating and consistent group dance motions. The experimental results show that our method achieves the desired levels of consistency and diversity while maintaining the overall quality of the generated group choreography.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Introduction

With the widespread presence of digital social media platforms, creating and editing dance videos has gained immense popularity among social communities. This surge in interest has resulted in millions of dance videos being produced and watched daily across online platforms. Recently, researchers from the computer vision, computer graphics, and machine learning communities have devoted considerable attention to developing techniques that can generate natural dance movements from music. These advancements have far-reaching implications and find applications in various domains, such as animation, the creation of virtual idols, the development of virtual metaverses, and dance education. These techniques empower artists, animators, and educators alike, providing them with powerful tools to enhance their creative endeavors and enrich the dance experience for both performers and audiences.

While significant progress has been made in generating dance motions for a single dancer, the task of producing cohesive and expressive choreography for a group of dancers has received limited attention. Generating synchronized group dance motions that are both realistic and aligned with music remains a challenging problem in computer animation and motion synthesis. This is primarily due to the complex relationship between music and human motion, the diverse range of motions required for group performances, and the lack of a suitable dataset. At present, AIOZ-GDance is the most recent large-scale dataset available for the task of generating group choreography. Besides, while current algorithms can generate individual movements and choreographic sequences, ensuring that these elements align seamlessly with the overall group performance is also paramount.

Different from solo dance, group dance involves coordination and interaction between dancers, making it crucial and challenging to establish correlations between motion series within a group. Besides, group dance can involve complex and diverse choreographies among participating dancers while still maintaining a semantic relationship between the motion and the input music. Exploring the consistency and diversity between the movements of dancers in the synthesized group choreography is vital for creating a natural and expressive performance. The ability to control consistency and diversity in group dance generation holds great potential across various applications. One such application is in the realm of entertainment and performance. Choreographers and creative teams can leverage this control to design captivating group dance routines that seamlessly blend synchronized movements with moments of individual expression. Second, in the context of animation and virtual metaverses, control over consistency and diversity allows for the creation of visually stunning and immersive virtual dance performances. By balancing the synchronization of dancers' movements while also introducing variations and unique flourishes, the generated group dances can captivate audiences and evoke a sense of realism and authenticity. Last but not least, in dance education and training, the ability to regulate consistency and diversity in group dance generation can be invaluable. It enables instructors to provide students with a diverse range of generated dance routines and samples that challenge their abilities, promote collaboration, and foster creativity. By dynamically adjusting the level of consistency and diversity, educators can cater to the unique needs and skill level of each individual dancer, creating more inclusive instruction and enriching the learning environment.
Although plenty of applications can be listed, consistency and diversity in group choreography have not been carefully explored, owing to limitations in the available data.

In this paper, our goal is to develop a controllable technique for group dance generation. We present a Group Contrastive Diffusion (GCD) strategy that learns an encoder to capture the key targets between group dance movements. Diffusion modeling provides a flexible framework for manipulating the dance distribution, which allows us to modulate the degree of diversity and consistency in the generated dances. By using a denoising diffusion probabilistic model as a key technique, we can effectively control the trade-off between diversity and consistency during group dance generation, thanks to the guided sampling process. With this approach, we can guide the generation process toward a desired balance between diversity and consistency. Moreover, incorporating the encoder, which learns the association between the dancers and their group, helps keep the generated dance moves consistent with a specific dance style, music genre, or long-term chorus. We empirically show that this approach has the potential to enhance the quality and naturalness of generated group dance performances, making them more appealing for various applications.

Figure 1. We present a contrastive diffusion method that controls the consistency (top row) and diversity (second row) in group choreography.

2. Overview

Music-driven Choreography. Creating natural and authentic human choreography from music is a complex task. One commonly employed technique involves using a motion graph derived from a vast motion database to generate new motions. This involves combining various motion segments and optimizing transition costs along the graph path. Alternatively, there are other methods that incorporate music-motion similarity matching constraints to ensure consistency between the motion and the accompanying music. Previous studies have extensively explored these methodologies. However, most of these approaches relied on heuristic algorithms to stitch together pre-existing dance segments sourced from a limited music-dance database. While these methods are successful in generating extended and realistic dance sequences, they face limitations when trying to create entirely novel dance fragments.

In recent years, significant progress has been made in music-to-dance motion generation using Convolutional Neural Networks (CNNs), Recurrent Neural Networks (RNNs), Graph Neural Networks (GNNs), Generative Adversarial Networks (GANs), and Transformers. Typically, these methods rely on multiple inputs, such as the current music and a brief history of past dance movements, to predict the future sequence of human poses. Recently, Gong et al. proposed the interesting task of generating dance by simultaneously using both music and text instructions. They designed a music-text feature fusion module that fuses the inputs into a motion decoder to generate dance conditioned on both music and text. However, although these methods can produce natural and realistic dancing motion, they often fail to create synchronized and harmonious movements between multiple dancers. Ensuring coordination and synchronization between dancers is a complicated problem, as it involves not only individual pose predictions but also the seamless integration of these poses within the context of a group. Achieving synchronized and harmonious group movements requires considering the spatial and temporal relationships among dancers, their interactions, and the overall choreographic structure. Further advances addressing these challenges include works that use deep learning approaches such as Variational Autoencoders (VAEs), GANs, and Normalizing Flows.

Unfortunately, most of these networks are limited in their ability to model long-term dance sequences; the generated dance may freeze or drift towards the end of the music. To facilitate long-term generation, many researchers apply a motion repeat constraint that predicts future frames by attending to the historical motions. Nevertheless, this limits the flexibility of the model by forcing it to always look into the past.

Group Choreography. Group choreography and its related problem, multi-person motion prediction, have been an active research area, with numerous studies addressing the challenges of predicting the behaviors of multiple individuals. Alahi et al. use a Markov chain model to jointly analyze the trajectories of several pedestrians and predict their destinations in a given scene. Adel et al. integrate social interactions and the visual context of the environment to forecast the future motion of multiple individuals. Multi-Range Transformers were proposed to predict the movements of groups of more than ten people engaged in social interactions. These methods leverage various techniques to capture social interactions, spatial dependencies, and temporal dynamics, generally aiming to predict accurate and socially plausible future motions for multiple individuals in different scenarios. However, despite these notable advances, the relationship between the consistency and diversity of motions within a group context still calls for further investigation. A deeper understanding of how to attain the optimal balance between consistency and diversity could unlock new possibilities for creating group choreographies that benefit users in many circumstances.

Diffusion for Music-driven Choreography. Recently, diffusion-based approaches have shown remarkable results on generative tasks ranging from image generation, audio synthesis, pose estimation, natural language generation, and motion synthesis to point cloud generation, 3D object synthesis, and scene creation. Unlike GANs, diffusion models have been shown to achieve high mode coverage while maintaining high sample quality. Most existing diffusion-based approaches for human motion and dance synthesis focus only on generating motion sequences for a single character, conditioned on information such as text, audio, or both. Different from these prior works, we aim to create group dance motions from music, which involves coordinating multiple characters, avoiding collisions, and maintaining coherence between them. In addition to the vanilla diffusion loss used for training in previous works, our method employs a contrastive learning strategy that directly influences the training of the diffusion reverse process, strengthening the association within the dance group.

A prominent issue with the diffusion approach to motion synthesis is that, although it is highly effective at generating diverse samples, injecting a large amount of noise during the sampling process can lead to inconsistent results. This issue is particularly problematic for group dance. Therefore, our desideratum is a group dance generation model that can address both the diversity and the consistency problems.

3. Preliminaries

3.1 Background

Given an input music sequence $\{a_1, \ldots, a_t, \ldots, a_T\}$, where $t \in \{1, \ldots, T\}$ indexes the music frames, our goal is to generate the group motion sequences of $N$ dancers, $\{x^1_1, \ldots, x^1_T; \ldots; x^N_1, \ldots, x^N_T\}$, where $x^i_t$ is the pose of the $i$-th dancer at frame $t$. We represent dances as sequences of poses of the 24-joint SMPL model, using the 6D continuous rotation representation for every joint along with a single 3D root translation. This rotation representation ensures the uniqueness and continuity of the rotation vector, which benefits the training of deep neural networks. We tackle the group dance generation task with a diffusion-based framework that synthesizes the motions from a random noise distribution given the music conditioning. The sampling process of the diffusion model lets us effectively control the consistency and diversity of the generated sequences.
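The 6D rotation representation mentioned above is the continuous representation of Zhou et al. (2019): two 3D vectors are turned into an orthonormal frame by Gram-Schmidt orthogonalization. The NumPy sketch below shows that mapping and how one frame could be unpacked. It assumes each pose vector $x^i_t$ stores the 24 joint rotations first and the root translation last (a hypothetical layout; the post does not specify the ordering), which gives $24 \times 6 + 3 = 147$ values per frame.

```python
import numpy as np

NUM_JOINTS = 24  # SMPL body joints


def rot6d_to_matrix(r6: np.ndarray) -> np.ndarray:
    """Map (..., 6) vectors to (..., 3, 3) rotation matrices via Gram-Schmidt."""
    a1, a2 = r6[..., :3], r6[..., 3:]
    b1 = a1 / np.linalg.norm(a1, axis=-1, keepdims=True)
    # Remove the component of a2 along b1, then normalize.
    a2_orth = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1
    b2 = a2_orth / np.linalg.norm(a2_orth, axis=-1, keepdims=True)
    b3 = np.cross(b1, b2)  # completes a right-handed frame
    return np.stack([b1, b2, b3], axis=-1)  # b1, b2, b3 as columns


def split_pose(x_t: np.ndarray):
    """Split one frame (hypothetical layout: 24*6 rotations, then 3D root) into parts."""
    rot6d = x_t[: NUM_JOINTS * 6].reshape(NUM_JOINTS, 6)
    root = x_t[NUM_JOINTS * 6:]
    return rot6d_to_matrix(rot6d), root
```

Whether $b_1, b_2, b_3$ are stacked as columns or rows is a convention; stacking them as rows yields the transpose.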

3.2 Forward Process of the Diffusion Model

Given an original sample from the real data distribution, $x_0 \sim q(x_0)$, and following Ho et al. (2020), the forward diffusion process is defined as a Markov process that gradually adds Gaussian noise to the data under a pre-defined noise schedule over $M$ steps:

$$q(x_m \mid x_{m-1}) = \mathcal{N}\left(x_m; \sqrt{1-\beta_m}\, x_{m-1}, \beta_m I\right), \quad \forall m \in \{1, 2, \ldots, M\}$$

If the noise variance schedule $\beta_m$ is small and the number of diffusion steps $M$ is large enough, the distribution $q(x_M)$ at the end of the process is well approximated by a standard normal distribution $\mathcal{N}(0, I)$, which is easy to sample from. A convenient property of the forward diffusion is that the noised sample at any arbitrary step $m$ can be obtained directly, without traversing the whole chain:

$$q(x_m \mid x_0) = \mathcal{N}\left(x_m; \sqrt{\bar{\alpha}_m}\, x_0, (1-\bar{\alpha}_m) I\right), \qquad x_m = \sqrt{\bar{\alpha}_m}\, x_0 + \sqrt{1-\bar{\alpha}_m}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

where $\alpha_m = 1 - \beta_m$ and $\bar{\alpha}_m = \prod_{s=1}^{m} \alpha_s$.
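As a sketch of the closed-form forward process, the NumPy snippet below noises a clean group-dance tensor to an arbitrary step $m$ in a single call. It assumes the linear schedule from $10^{-4}$ to $0.02$ used by Ho et al. (2020), since the post does not state which schedule GCD uses, and a hypothetical tensor of shape (dancers, frames, 147).

```python
import numpy as np


def make_schedule(M: int, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Linear noise schedule (an assumption; the post does not specify one).

    Arrays are indexed by the diffusion step m = 0..M. Index 0 is padding with
    beta = 0, so alpha_bar[0] = 1 (no noise) and alpha_bar[m] = prod_{s=1}^{m} alpha_s.
    """
    betas = np.concatenate([[0.0], np.linspace(beta_start, beta_end, M)])
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    return betas, alphas, alpha_bar


def q_sample(x0: np.ndarray, m: int, alpha_bar: np.ndarray, rng: np.random.Generator):
    """Draw x_m ~ q(x_m | x_0) in one shot, without iterating over the chain."""
    eps = rng.standard_normal(x0.shape)
    x_m = np.sqrt(alpha_bar[m]) * x0 + np.sqrt(1.0 - alpha_bar[m]) * eps
    return x_m, eps


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M = 1000
    _, _, alpha_bar = make_schedule(M)
    # Hypothetical group-dance tensor: N dancers x T frames x (24 joints * 6D + 3D root).
    N, T, D = 3, 120, 24 * 6 + 3
    x0 = rng.standard_normal((N, T, D))
    x_M, _ = q_sample(x0, M, alpha_bar, rng)
    print(f"alpha_bar[M] = {alpha_bar[M]:.2e}")  # close to 0, so x_M is almost pure noise
```

Padding index 0 keeps the array indices aligned with the step index $m$ in the equations, which avoids off-by-one errors when the same schedule is reused for the reverse process.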

3.3 Reverse Process

By additionally conditioning on $x_0$, the posterior of the reverse process becomes tractable and is a Gaussian distribution:

$$q(x_{m-1} \mid x_m, x_0) = \mathcal{N}\left(x_{m-1}; \tilde{\mu}_m, \tilde{\beta}_m I\right),$$

where $\tilde{\mu}_m$ and $\tilde{\beta}_m$ are the posterior mean and variance; the mean depends on both $x_m$ and $x_0$, while the variance depends only on the noise schedule. To obtain a sample from the original data distribution, we start by sampling from the noise distribution $q(x_M)$ and then gradually remove the noise, following the reverse process, until we reach $x_0$. Since $x_0$ is unknown at sampling time, our goal is to train a neural network to approximate the reverse distribution $q(x_{m-1} \mid x_m)$ as:

$$p_\theta(x_{m-1} \mid x_m) = \mathcal{N}\left(x_{m-1}; \mu_\theta(x_m, m), \Sigma_\theta(x_m, m)\right)$$

We model only the mean $\mu_\theta(x_m, m)$ of the reverse distribution and keep the variance $\Sigma_\theta(x_m, m)$ fixed according to the noise schedule. However, instead of predicting the noise $\epsilon_m$ at each step $m$, as in the original DDPM formulation, we train the network to predict the original noiseless signal $x_0$ directly. The sample at the previous step $m-1$ is then obtained by noising the predicted $x_0$ back to that step. In the conditional generation setting, the network is additionally conditioned on a conditioning signal $c$, so that $x_0 \approx \mathcal{G}_\theta(x_m, m, c)$ with model parameters $\theta$.
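One way to read "noising the predicted $x_0$ back" is to sample $x_{m-1}$ from the closed-form posterior $q(x_{m-1} \mid x_m, \hat{x}_0)$ of Ho et al. (2020), with the network output $\hat{x}_0 = \mathcal{G}_\theta(x_m, m, c)$ substituted for $x_0$:

$$\tilde{\mu}_m = \frac{\sqrt{\bar{\alpha}_{m-1}}\,\beta_m}{1-\bar{\alpha}_m}\,\hat{x}_0 + \frac{\sqrt{\alpha_m}\,(1-\bar{\alpha}_{m-1})}{1-\bar{\alpha}_m}\,x_m, \qquad \tilde{\beta}_m = \frac{1-\bar{\alpha}_{m-1}}{1-\bar{\alpha}_m}\,\beta_m.$$

The NumPy sketch below implements that sampler. It is an illustration under stated assumptions, not the GCD implementation: the linear schedule is assumed, $\bar{\alpha}_0 = 1$ by convention, and `G` is a hypothetical stand-in for $\mathcal{G}_\theta$.

```python
import numpy as np


def make_schedule(M: int, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Linear schedule indexed by m = 0..M, with alpha_bar[0] = 1 (assumption)."""
    betas = np.concatenate([[0.0], np.linspace(beta_start, beta_end, M)])
    alphas = 1.0 - betas
    return betas, alphas, np.cumprod(alphas)


def reverse_step(x_m, x0_hat, m, betas, alphas, alpha_bar, rng):
    """Sample x_{m-1} ~ q(x_{m-1} | x_m, x0_hat) using the DDPM posterior."""
    coef_x0 = np.sqrt(alpha_bar[m - 1]) * betas[m] / (1.0 - alpha_bar[m])
    coef_xm = np.sqrt(alphas[m]) * (1.0 - alpha_bar[m - 1]) / (1.0 - alpha_bar[m])
    mean = coef_x0 * x0_hat + coef_xm * x_m
    # Fixed (not learned) variance; it is exactly 0 at m = 1.
    var = (1.0 - alpha_bar[m - 1]) / (1.0 - alpha_bar[m]) * betas[m]
    return mean + np.sqrt(var) * rng.standard_normal(x_m.shape)


def sample(G, c, shape, M, rng):
    """Run the reverse chain from x_M ~ N(0, I) down to x_0.

    G(x_m, m, c) is a hypothetical stand-in for the network that predicts x_0.
    """
    betas, alphas, alpha_bar = make_schedule(M)
    x = rng.standard_normal(shape)
    for m in range(M, 0, -1):
        x0_hat = G(x, m, c)
        x = reverse_step(x, x0_hat, m, betas, alphas, alpha_bar, rng)
    return x
```

With a perfect predictor the chain lands on the target at $m = 1$, because $\tilde{\beta}_1 = 0$ and the weight on $x_m$ vanishes there; in practice `G` would be the trained network and `c` the music features.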


In the next post, we will describe our proposal in detail.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 4)

In the previous post, we described the multigraph parsing process, how to train a multigraph under decentralized federated learning, and the experimental setup. In this post, we discuss the effectiveness and efficiency of the multigraph topology design under different configurations.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Cycle Time Comparison

Table 1 compares the cycle time of our method with other recent approaches. Our proposed method significantly reduces the cycle time in all setups, across different networks and datasets. In particular, compared to the state-of-the-art RING, our method reduces the cycle time by 2.18, 1.5, and 1.74 times on average on the FEMNIST, iNaturalist, and Sentiment140 datasets, respectively. Our method also clearly outperforms MATCHA, MATCHA(+), and MST by a large margin. These results confirm that a multigraph with isolated nodes helps reduce the cycle time in federated learning.

Table 1 also shows that our multigraph achieves its smallest improvement under the Amazon network on all three datasets. This is because, under the Amazon network, our proposed topology does not generate many isolated nodes, which limits the gain. Intuitively, when there are no isolated nodes, our multigraph reduces to the overlay, and its cycle time equals the cycle time of the overlay in RING.


2. Isolated Node Analysis

Isolated Nodes vs. Network Configuration. The number of isolated nodes varies with the network configuration (Amazon, Gaia, Exodus, etc.). The parameter $t$ (the maximum number of edges between two nodes) and the delay time, which is determined by many factors (geographic distance, model size, task-dependent computational cost, bandwidth, etc.), also affect how isolated nodes are generated. Table 2 illustrates the effectiveness of isolated nodes under different network configurations. Specifically, we conduct experiments on the FEMNIST dataset using five network configurations (Gaia, Amazon, Géant, Exodus, Ebone). The reduction of our cycle time compared with RING becomes more significant when more communication rounds or graph states contain isolated nodes.

Table 2: The effectiveness of isolated nodes under different network configurations. All experiments are trained for 6,400 communication rounds on the FEMNIST dataset. We record the number of states and rounds in which isolated nodes appear and compare our cycle time with RING.

Isolated Nodes vs. RING vs. Random Strategy. Isolated nodes play an important role in our method, as they let us skip the model aggregation step at those nodes. A trivial alternative for creating isolated nodes is to randomly remove some nodes from the overlay of RING. Table 3 shows the results of two such scenarios on the FEMNIST dataset and the Exodus network: (i) randomly removing some silos from the overlay of RING, and (ii) removing the most inefficient silos (i.e., silos with the longest delay) from the overlay of RING. Both scenarios reduce the cycle time significantly, but the accuracy of the model also drops significantly. This experiment shows that, although removing nodes from the overlay of RING is a trivial solution, it cannot maintain model accuracy. In contrast, our multigraph both reduces the cycle time and preserves accuracy. This is because our multigraph skips the aggregation step of the isolated nodes only within a given communication round; in the next round, the delay times of these nodes are updated, and they can become normal nodes and contribute to the final model.


Table 3: The cycle time and accuracy of our multigraph vs. RING with different criteria.

Isolated Nodes Illustration. The figure below gives a detailed illustration of our algorithm with isolated nodes in a real-world training scenario. The experiment is conducted on the Gaia network geometry and its corresponding hardware, which supports link latency computation. We use the image classification task on the FEMNIST dataset with the CNN backbone provided by Marfoq et al. We keep the transmitted model size at 4.62 Mb, all access links have 10 Gbps traffic capacity, the number of local updates is set to 1, and the maximum number of edges $t$ is set to 3. As shown in the figure, although there are no isolated nodes in the initial state, the number of isolated nodes increases in subsequent states, reaching 4 nodes. This leads to a $\sim 4$ times reduction in cycle time compared to the initial state. The appearance of isolated nodes also reduces the number of connections between silos by $\sim 3.6$ times, from 11 down to 3, and the discarded connections all have high latency.


3. Multigraph Ablation Study

Accuracy Analysis. In federated learning, improving the model accuracy is not the main focus of topology design methods. However, preserving accuracy is still important to ensure model convergence. Table 4 shows the accuracy of different topologies after 6,400 communication rounds on the FEMNIST dataset. Our proposed method achieves accuracy competitive with other topology designs, confirming that our topology maintains model accuracy while significantly reducing training time.


Table 4: Accuracy comparison between different topologies. The experiment is conducted using the FEMNIST dataset. The accuracy is reported after 6,400 communication rounds in all methods.

Convergence Analysis. The figure below shows the training loss versus the number of communication rounds and the wall-clock time under the Exodus network on the FEMNIST dataset. Our proposed topology converges faster than the other methods while maintaining model accuracy. We observe the same results on other datasets and network setups.

Cycle Time and Accuracy Trade-off. In our method, the maximum number of edges between two nodes, $t$, mainly affects the number of isolated nodes, which leads to a trade-off between model accuracy and cycle time. Table 5 illustrates the effect of this parameter. When $t = 1$, there are by construction no weakly-connected edges and hence no isolated nodes, so our method uses the original overlay from RING. Setting $t$ higher increases the number of isolated nodes and hence decreases the cycle time. In practice, however, too many isolated nodes limit the exchange of model weights between silos. Models at isolated nodes then become biased toward their local data, which degrades the final accuracy.


Table 5: Cycle time and accuracy trade-off with different values of $t$, i.e., the maximum number of edges between two nodes.

4. Conclusion

We proposed a new multigraph topology for cross-silo federated learning. Our method first constructs the multigraph from the overlay. Different graph states are then parsed from the multigraph and used in each communication round. Our method significantly reduces the cycle time by allowing the isolated nodes in the multigraph to perform model aggregation without waiting for other nodes. Extensive experiments on three datasets show that our proposed topology achieves new state-of-the-art results in all network and dataset setups.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 3)

In the previous part, we investigated how the delay time and cycle time are affected by using a multigraph in the topology design, and how the multigraph can be constructed. In this post, we describe the multigraph parsing process, how to train a multigraph under decentralized federated learning, and the experimental setup.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Multigraph Parsing

In our state-formation algorithm, we parse the multigraph $\mathcal{G}_m$ into multiple graph states $\mathcal{G}_m^s$. Graph states are essential for identifying the connection status of silos in a specific communication round in order to perform model aggregation. In each graph state, our goal is to identify the isolated nodes. During training, isolated nodes update their weights internally and ignore all weakly-connected edges that connect to them.

To parse the multigraph into graph states, we first identify the maximum number of states in the multigraph, $s_{\max}$, using the least common multiple (LCM). We then parse the multigraph into $s_{\max}$ states. The first state is always the overlay, since we want all silos to have a reliable topology at the beginning to ease training. The remaining states are parsed so that there is only one connection between any two nodes. With this procedure, some states will contain isolated nodes. During training, only one graph state is used per communication round. The figure below illustrates the training process in each communication round using multiple graph states.
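A minimal sketch of the parsing step follows, under assumptions that the post does not spell out: the multigraph is given as the number of edges $n(i,j)$ per silo pair (one strongly-connected, the rest weakly-connected), $s_{\max}$ is taken as the LCM of all $n(i,j)$, and in state $s$ the strong edge of a pair is active when $s$ is a multiple of $n(i,j)$. Under this hypothetical rule, state 0 reproduces the overlay, and a silo is isolated in a state when none of its edges is strong.

```python
from functools import reduce
from math import lcm

Pair = tuple[int, int]


def parse_states(num_edges: dict[Pair, int]) -> list[dict]:
    """Split a multigraph into simple graph states (illustrative sketch).

    num_edges[(i, j)] is the number of edges between silos i and j:
    one strongly-connected edge plus num_edges - 1 weakly-connected ones.
    """
    s_max = reduce(lcm, num_edges.values(), 1)
    nodes = {v for pair in num_edges for v in pair}
    states = []
    for s in range(s_max):
        # Assumed rule: a pair's strong edge is active every n(i, j) states,
        # so state 0 keeps every pair strong and reproduces the overlay.
        strong = {p for p, n in num_edges.items() if s % n == 0}
        touched = {v for p in strong for v in p}
        # Silos with no strong edge in this state are isolated.
        states.append({"strong": strong, "isolated": nodes - touched})
    return states
```

For three pairs with 1, 2, and 3 edges, $s_{\max} = 6$; in state 1 only the single-edge pair stays strong, so silos 2 and 3 are isolated. Using state $k \bmod s_{\max}$ in round $k$ gives exactly one state per communication round.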

2. Multigraph Training

In each communication round, a graph state $\mathcal{G}_m^s$ is selected in sequence, and it determines the topology used for training in that round. We then collect all strongly-connected edges in $\mathcal{G}_m^s$: nodes with strongly-connected edges need to wait for their neighbors, while isolated nodes can update their models immediately. We train our multigraph with the DPASGD algorithm:

$$w_i(k+1) = \begin{cases} \sum_{j \in \mathcal{N}_i^{++} \cup \{i\}} A_{i,j}\, w_j(k-h), & \text{if } k \equiv 0 \text{ and } \left|\mathcal{N}_i^{++}\right| > 1, \\ w_i(k) - \alpha_k \frac{1}{b} \sum_{h=1}^{b} \nabla L_i\left(w_i(k), \xi_i^{(h)}(k)\right), & \text{otherwise}, \end{cases}$$

where $(k-h)$ is the index of the weights being used; $h$ is initialized to $0$ and incremented ($h \leftarrow h + 1$) as long as $e_{k-h}(i,j) = 0$. According to this equation, in each state, a silo that is not an isolated node must wait for the models of its neighbors before updating its weights. An isolated node, in contrast, can immediately update its weights using its neighbors' models from round $k-h$. The training procedure is summarized in the accompanying algorithm.
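The two ingredients of the update rule, finding the stale round $k-h$ and mixing (possibly stale) neighbor models, can be sketched as follows. This is an illustration only: `history`, `edge_status`, and the caller-supplied `neighbors` list are hypothetical data structures, and the exact scheduling of aggregation versus local steps in DPASGD is not reproduced here.

```python
import numpy as np


def stale_round(edge_status: list[int], k: int) -> int:
    """Return k - h, where h starts at 0 and grows while e_{k-h}(i, j) == 0.

    edge_status[r] is 1 if the edge was strongly connected in round r, else 0.
    Round 0 is the overlay (all edges strong), so the search stops there at worst.
    """
    h = 0
    while k - h > 0 and edge_status[k - h] == 0:
        h += 1
    return k - h


def aggregate(i, k, A, history, edge_status, neighbors):
    """Mix silo i's current model with neighbor models from their last strong round.

    history[r][j] holds w_j(r); A is the mixing matrix; edge_status[(j, i)] is
    the per-round status list of edge j -> i.
    """
    out = A[i, i] * history[k][i]
    for j in neighbors:
        r = stale_round(edge_status[(j, i)], k)
        out = out + A[i, j] * history[r][j]
    return out


def local_step(w, grads, lr):
    """Mini-batch SGD: w - lr * (1/b) * sum of the b per-sample gradients."""
    return w - lr * np.mean(grads, axis=0)
```

In the usage below the edge from silo 1 to silo 0 is weak in round 1, so silo 0 mixes in silo 1's model from round 0 instead of waiting for it.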

3. Algorithm Complexity

It is easy to see that the complexity of the training procedure is $\mathcal{O}(n^2)$, where $n$ is the number of silos. In practice, since the cross-silo federated learning setting has only a few hundred silos ($n < 500$), executing our algorithms takes only a tiny fraction of the training time. Therefore, our proposed topology can still significantly reduce the overall wall-clock training time.

4. Experimental Setups

Datasets. We use three datasets in our experiments: Sentiment140, iNaturalist, and FEMNIST. Dataset preparation and pre-processing follow recent works. The table below shows the dataset setups in detail.

Network. We consider five distributed networks in our experiments: Exodus, Ebone, Géant, Amazon, and Gaia. The Exodus, Ebone, and Géant networks are from the Internet Topology Zoo. The Amazon and Gaia networks are synthetic and are constructed using the geographical locations of data centers.

Baselines. We compare our multigraph topology with recent state-of-the-art topology designs for federated learning: STAR, MATCHA, MATCHA(+), MST, and RING.

Hardware Setup. Since measuring the cycle time is crucial for comparing the effectiveness of all topologies in practice, we test and report the cycle time of all baselines and our method on the same NVIDIA Tesla P100 16 GB GPUs. No overclocking is used.

Time Simulator. We adapted PyTorch with the MPI backend to run DPASGD and DPASGD++ on a GPU cluster. We use a network simulator, the Time Simulator, which takes an arbitrary topology and the computation times of silos as input and calculates the time instants at which local models are computed. Reconstructing the wall-clock time with this simulator requires a thorough understanding of the topology, including all factors in the delay equations for each network configuration. The related configuration information is provided with the Gaia network, and the simulator was created by Marfoq et al.


In the next post, we will discuss the effectiveness and efficiency of the multigraph topology design under different configurations.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 2)

In the previous part, we explored decentralized federated learning and how and why the multigraph is proposed to improve the training process. In this part, we investigate how the delay time and cycle time are affected by using a multigraph in the topology design. We also explore how the multigraph can be constructed.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Delay time in multigraph

The delay of an edge $e(i,j)$ is the time it takes for node $j$ to receive the weights sent by node $i$, defined as:

$$d(i,j) = u \times T_c(i) + l(i,j) + \frac{M}{O(i,j)},$$

where $T_c(i)$ denotes the time to compute one local update of the model; $u$ is the number of local updates; $l(i,j)$ is the link latency; $M$ is the model size; and $O(i,j)$ is the total network traffic capacity. Unlike other communication infrastructures, the multigraph only contains connections between silos, without other nodes such as routers or amplifiers. Thus, the total network traffic capacity is $O(i,j) = \min\left(\frac{C_{\rm UP}(i)}{|\mathcal{N}_i^{-}|}, \frac{C_{\rm DN}(j)}{|\mathcal{N}_j^{+}|}\right)$, where $C_{\rm UP}$ and $C_{\rm DN}$ denote the upload and download link capacities, and $\mathcal{N}_i^{-}$ and $\mathcal{N}_j^{+}$ are the out-neighbors of $i$ and the in-neighbors of $j$, respectively. Note that the upload and download processes can happen in parallel.
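The delay formula can be transcribed directly, with a worked example. The numbers are assumptions for illustration only: a 4.62 Mb model (read as megabits), 10 Gbps (10,000 Mbps) upload and download links shared by two out-neighbors and two in-neighbors, one local update taking 0.05 s, and 0.02 s of link latency.

```python
def link_capacity(c_up_i: float, c_dn_j: float, out_deg_i: int, in_deg_j: int) -> float:
    """O(i, j): i's upload capacity split across its out-neighbors versus
    j's download capacity split across the silos that send to j."""
    return min(c_up_i / out_deg_i, c_dn_j / in_deg_j)


def edge_delay(u: int, t_comp_i: float, latency: float,
               model_size: float, capacity: float) -> float:
    """d(i, j) = u * T_c(i) + l(i, j) + M / O(i, j). Units must be consistent
    (here: seconds, megabits, and megabits per second)."""
    return u * t_comp_i + latency + model_size / capacity


if __name__ == "__main__":
    O = link_capacity(10_000.0, 10_000.0, out_deg_i=2, in_deg_j=2)  # Mbps
    d = edge_delay(u=1, t_comp_i=0.05, latency=0.02, model_size=4.62, capacity=O)
    print(O, round(d, 6))
```

Here $O(i,j) = 5000$ Mbps, so transmission adds only about 0.9 ms and the delay of roughly 0.0709 s is dominated by computation and latency.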

Since a multigraph can contain multiple edges between two nodes, we extend the delay definition above to $d_k(i,j)$, where $k$ denotes the $k$-th communication round of training:

$$d_{k+1}(i,j) = \begin{cases} d_k(i,j), & \text{if } e_{k+1}(i,j) = 1 \text{ and } e_k(i,j) = 1, \\ \max\left(u \times T_c(j),\; d_k(i,j) - d_{k-1}(i,j)\right), & \text{if } e_{k+1}(i,j) = 1 \text{ and } e_k(i,j) = 0, \\ \tau_k(\mathcal{G}_m) + d_{k-1}(i,j), & \text{if } e_{k+1}(i,j) = 0 \text{ and } e_k(i,j) = 1, \\ \tau_k(\mathcal{G}_m), & \text{otherwise}, \end{cases}$$

where $e(i,j) = 0$ denotes a weakly-connected edge and $e(i,j) = 1$ a strongly-connected edge, and $\tau_k(\mathcal{G}_m)$ is the cycle time at the $k$-th communication round of training.

In general, according to this equation, the delay of the next communication round, $d_{k+1}$, is updated based on the delays of the previous rounds and other factors, depending on the type of edge connection.
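The four cases of the delay update map directly onto a small function. The sketch below is a literal transcription with hypothetical argument names, where `e_next` and `e_curr` are the edge statuses $e_{k+1}(i,j)$ and $e_k(i,j)$ (1 = strong, 0 = weak).

```python
def next_delay(d_k: float, d_km1: float, e_next: int, e_curr: int,
               u: int, t_comp_j: float, tau_k: float) -> float:
    """d_{k+1}(i, j) from d_k(i, j), d_{k-1}(i, j), and the edge statuses."""
    if e_next == 1 and e_curr == 1:
        return d_k                              # strong -> strong: delay unchanged
    if e_next == 1 and e_curr == 0:
        return max(u * t_comp_j, d_k - d_km1)   # weak -> strong
    if e_next == 0 and e_curr == 1:
        return tau_k + d_km1                    # strong -> weak
    return tau_k                                # weak -> weak: one cycle time
```

The `max` in the weak-to-strong case guarantees that the delay never drops below the receiving silo's own local computation time $u \times T_c(j)$.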

2. Cycle time in multigraph

The cycle time per round is the time required to complete a communication round. In this work, we define the cycle time per round as the maximum delay over all silo pairs with strongly-connected edges. Therefore, the average cycle time over an entire training run of $K$ rounds is:

$$\tau(\mathcal{G}_m) = \frac{1}{K} \sum_{k=0}^{K-1} \left( \max_{j \in \mathcal{N}_i^{++} \cup \{i\},\ \forall i \in \mathcal{V}} d_k(j,i) \right),$$

where $\mathcal{N}_i^{++}$ is the set of in-neighbors of silo $i$ connected to it by strongly-connected edges, and $\mathcal{V}$ is the set of silos.
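A sketch of the per-round cycle time and its average follows. It assumes the caller supplies every needed $d_k(j,i)$, including the self term $d_k(i,i)$ (its value, e.g. the local computation time, is an assumption), and a hypothetical map `strong_in` from each silo to its set $\mathcal{N}_i^{++}$.

```python
def round_cycle_time(delays: dict[tuple[int, int], float],
                     strong_in: dict[int, set[int]]) -> float:
    """Max of d_k(j, i) over every silo i and every j in N_i^{++} ∪ {i}.

    delays[(j, i)] is d_k(j, i) for this round; edges that are only weakly
    connected in this round are simply absent from strong_in and are ignored.
    """
    return max(delays[(j, i)] for i, nbrs in strong_in.items() for j in nbrs | {i})


def average_cycle_time(per_round: list[float]) -> float:
    """tau(G_m): mean of the per-round cycle times over K rounds."""
    return sum(per_round) / len(per_round)
```

In the usage below, silo 2 is isolated in the first round, so its slow 0.9 s link to silo 1 does not count toward that round's cycle time; once the link is strongly connected, it dominates.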

3. Multigraph Construction

Algorithm 1 describes our method for generating the multigraph $\mathcal{G}_m$ with multiple edges between silos. The algorithm takes the overlay $\mathcal{G}_o$ as input. Similar to Marfoq et al. (2020), we use the Christofides algorithm to obtain the overlay. In Algorithm 1, we establish multiple edges that indicate different statuses (strongly-connected or weakly-connected). To determine the total number of edges between a silo pair, we divide the delay $d(i,j)$ by the smallest delay $d_{\min}$ over all silo pairs and compare the result with the maximum-number-of-edges parameter $t$ ($t = 5$ in our experiments). We assume that silo pairs with longer delays will have more weakly-connected edges and are therefore more likely to become isolated nodes. Overall, we aim to increase the number of weakly-connected edges, which generates more isolated nodes and speeds up the training process. Note that, in Algorithm 1, each silo pair in the multigraph has one strongly-connected edge and may have multiple weakly-connected edges. The strongly-connected edge ensures that the two silos have a good connection in at least one communication round.
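A minimal sketch of the edge-count rule is shown below, assuming the ratio $d(i,j)/d_{\min}$ is truncated to an integer and clipped to the range $[1, t]$. The post states only that the ratio is compared with $t$, so the rounding and clipping are assumptions, and the function name is hypothetical. The input maps each overlay pair to its delay.

```python
def build_multigraph(overlay_delays: dict[tuple[int, int], float],
                     t: int) -> dict[tuple[int, int], int]:
    """Number of edges per overlay pair: 1 strong + (n - 1) weak, with n <= t.

    Pairs whose delay is many times the smallest delay get more weakly-connected
    edges, so they are more likely to leave a silo isolated in later states.
    """
    d_min = min(overlay_delays.values())
    return {pair: max(1, min(t, int(d / d_min)))
            for pair, d in overlay_delays.items()}
```

With $t = 1$ every pair keeps a single strongly-connected edge, which recovers the plain overlay, consistent with the role of $t$ in the cycle-time and accuracy trade-off.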


In the next post, we will describe the multigraph parsing process and how to train a multigraph under decentralized federated learning.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 1)

Federated learning is an active research topic because it enables several participants to jointly train a model without sharing local data. Currently, cross-silo federated learning is a popular training setting that uses a few hundred reliable data silos with high-speed access links to train a model. While this approach has been widely applied in real-world scenarios, designing a robust topology that reduces training time remains an open problem. In this paper, we present a new multigraph topology for cross-silo federated learning. We first construct the multigraph using the overlay graph. We then parse this multigraph into different simple graphs with isolated nodes. The existence of isolated nodes allows us to perform model aggregation without waiting for other nodes, effectively reducing the training time. Extensive experiments on three public datasets show that our proposed method significantly reduces the training time compared with recent state-of-the-art topologies while maintaining the accuracy of the learned model.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Introduction

Federated learning involves training models using remote devices or isolated data centers while keeping the data localized to respect user privacy policies. According to available literature, there are two prominent training scenarios: the "cross-device" scenario, which includes numerous unreliable edge devices with limited computational capacity and slow connection speeds, and the "cross-silo" scenario, which features a smaller number of reliable data silos with powerful computing resources and high-speed access links. Recently, the cross-silo scenario has gained traction in various federated learning applications.

In practical terms, federated learning represents a promising research avenue that allows us to harness the capabilities of machine learning techniques while upholding user privacy. Key obstacles in federated learning encompass issues like model convergence, communication bottlenecks, and disparities in data distributions across different silos. A commonly employed federated training approach involves establishing a central node responsible for overseeing the training process and aggregating contributions from all clients. However, a drawback of this client-server approach is the potential for communication bottlenecks, especially when dealing with a large number of clients. To mitigate this limitation, recent research has explored the concept of decentralized or peer-to-peer federated learning, where communication occurs via a peer-to-peer network topology, eliminating the need for a central node. Nevertheless, a major challenge in decentralized federated learning remains achieving rapid training while ensuring model convergence and preserving model accuracy.

In federated learning, the structure of communication networks holds significant importance. Specifically, an efficient network design contributes to quicker convergence, result