58 posts tagged with "ai"


CathAction - A Benchmark for Endovascular Intervention Understanding (Part 3)

1. Tasks and Benchmarks

In this section, we benchmark five tasks, including anticipation, recognition, segmentation, collision detection, and domain adaptation, to demonstrate the usefulness of the CathAction dataset. We then discuss the challenges and opportunities for improvement in each task.

A. Catheterization Anticipation

The anticipation task aims to predict the next catheterization action based on a sequence of frames. We adapt the conventional anticipation task framework in computer vision, introducing two timing parameters: anticipation time ($\tau_a$) and observation time ($\tau_o$). The anticipation time denotes the required duration to recognize an action, while the observation time indicates the length of the video footage to analyze before making a prediction. The objective is to predict the action class $c_a$ for the frames within the anticipation time $\tau_a$, given the frames during the observation time $\tau_o$.
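To make the two timing parameters concrete, the following minimal sketch computes which frames are observed for a given action. It assumes the convention common in anticipation benchmarks, where the observed clip ends $\tau_a$ seconds before the action starts; the post does not spell this out, so treat it as an assumption. The function name `observation_window` is hypothetical, and the 24 FPS default is taken from the dataset description in Part 2.

```python
def observation_window(action_start_s, tau_a=1.0, tau_o=1.0, fps=24):
    """Return (first_frame, last_frame) of the observed clip, last exclusive.

    Assumption: the observed clip ends tau_a seconds before the action
    starts and spans tau_o seconds of video.
    """
    obs_end_s = action_start_s - tau_a
    obs_start_s = obs_end_s - tau_o
    if obs_start_s < 0:
        raise ValueError("not enough video before the action to observe")
    return int(round(obs_start_s * fps)), int(round(obs_end_s * fps))


# An action starting at t = 5 s is predicted from the frames covering 3 s to 4 s.
first, last = observation_window(5.0)
```

With $\tau_a = \tau_o = 1$ s, each prediction is made from 24 frames that end one second before the action begins.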

Network and Training. We leverage state-of-the-art action anticipation methods as baselines: CNN&RNN, RU-LSTM, TempAggRe-Fusion, AFFT, and Trans-SVNet. The future action predictions are supervised using cross-entropy loss with labeled future actions. Following prior works, we set $\tau_a = 1\,\text{s}$ and $\tau_o = 1\,\text{s}$. Training was performed on a single Nvidia A100 GPU with a batch size of 64 for 80 epochs, starting with a learning rate of 0.001, reduced by a factor of 10 after epochs 30 to 60. We split approximately 80% of the dataset for training and 20% for testing. Performance metrics include top-1 accuracy, precision, and recall.
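As a rough illustration of the three reported metrics, the sketch below computes top-1 accuracy together with precision and recall. The post does not say how precision and recall are averaged over classes; macro averaging is assumed here, and the function name `classification_metrics` is hypothetical.

```python
from collections import Counter


def classification_metrics(y_true, y_pred):
    """Top-1 accuracy and macro-averaged precision/recall, as fractions in [0, 1].

    Assumption: precision and recall are averaged with equal weight per class.
    """
    classes = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but it was wrong
            fn[t] += 1  # true class t was missed
    accuracy = sum(tp.values()) / len(y_true)
    precisions = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0 for c in classes]
    recalls = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0 for c in classes]
    return accuracy, sum(precisions) / len(classes), sum(recalls) / len(classes)


y_true = ["rotate", "rotate", "advance guidewire", "advance guidewire"]
y_pred = ["rotate", "advance guidewire", "advance guidewire", "advance guidewire"]
accuracy, precision, recall = classification_metrics(y_true, y_pred)
```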

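| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| CNN | CVPR 2018 | 28.98 | 30.14 | 29.76 |
| RNN | CVPR 2018 | 29.64 | 30.38 | 30.44 |
| RU-LSTM | CVPR 2019 | 35.08 | 34.29 | 34.77 |
| TempAggRe | ECCV 2020 | 34.64 | 35.56 | 34.71 |
| Trans-SVNet | IJCARS 2022 | 29.06 | 19.67 | 20.28 |
| AFFT | WACV 2023 | 37.91 | 36.87 | 37.63 |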

Table 1: Catheterization anticipation results on the CathAction dataset. All values are reported in percentages (%).

Figure 1

Figure 1: Qualitative catheterization prediction results. The predicted and ground-truth next actions are displayed on the right of each sample. Green indicates a correct prediction, and red indicates an incorrect prediction.

Results. Table 1 shows the catheterization anticipation results of different baselines. The transformer-based AFFT achieves the best accuracy, precision, and recall, outperforming the CNN- and LSTM-based models. Qualitative results are illustrated in Fig 1. We can see that transformer-based models can make more accurate predictions in challenging scenarios, especially when the catheter is moving quickly or is partially occluded.

Discussion. Despite the advancements, existing methods for catheterization anticipation still struggle to achieve high accuracy, revealing areas for future research. The rapid motion of the catheter and guidewire poses significant challenges for this task, and real-time performance is crucial as surgeons require immediate feedback during procedures.

B. Catheterization Recognition

Following the traditional action recognition task in computer vision, in catheterization recognition, given an input video segment, our goal is to predict the action class for that video segment.

Network and Training. We explore state-of-the-art methods in action recognition to benchmark the catheterization recognition task, including TDN, Video Swin Transformer, and BERT Pretraining of Video Transformers (BEVT). Each model is trained using two Nvidia A100 GPUs for 80 epochs, with a mini-batch size of 512. The initial learning rates are set to 0.01 for the spatial stream and 0.001 for the temporal stream, and are reduced by a factor of 10 at epochs 20 and 40. All other parameters are reused from the baseline methods.
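The step schedule described above can be sketched in a few lines. This is only an illustration: `step_lr` is a hypothetical helper, and it assumes zero-indexed epochs with each decay taking effect from the milestone epoch onward, which the post does not specify.

```python
def step_lr(base_lr, epoch, milestones=(20, 40), gamma=0.1):
    """Learning rate at a given epoch under a step decay schedule.

    Assumption: epochs are zero-indexed and the rate is multiplied by
    gamma once for every milestone that has been reached.
    """
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** drops


# Spatial stream starts at 0.01, temporal stream at 0.001.
spatial_schedule = [step_lr(0.01, e) for e in (0, 20, 40)]
temporal_schedule = [step_lr(0.001, e) for e in (0, 20, 40)]
```

Under these assumptions the spatial stream runs at roughly 0.01, 0.001, and 0.0001 over the three phases, and the temporal stream at one tenth of those values.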

Results. Table 2 summarizes the catheterization recognition results of three baseline methods (TDN, Video Swin Transformer, and BEVT) on the CathAction dataset. TDN with ResNet101 achieves the best top-1 accuracy of 62.5% on five classes. Action recognition in endovascular intervention remains challenging because catheters and guidewires look similar across different environments, while the actions themselves depend on the visual characteristics of these instruments.

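| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| TDN-ResNet50 | CVPR 2021 | 58.34 | 59.12 | 57.22 |
| TDN-ResNet101 | CVPR 2021 | 62.50 | 61.89 | 62.77 |
| Video Swin Transformer | CVPR 2022 | 51.67 | 52.14 | 51.24 |
| BEVT | CVPR 2022 | 49.28 | 50.27 | 49.92 |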

Table 2: Catheterization recognition results on the CathAction dataset. All values are reported in percentages (%).

Discussion. Compared to the anticipation task (Table 1), catheterization recognition methods (Table 2) show higher accuracy. However, the overall performance is not yet sufficient for real-world applications. Further research can utilize advanced techniques such as multi-modality learning, combining pre-operative data or synthetic data with transfer learning to improve the results. Additionally, exploring the capabilities of large-scale medical foundation models is an interesting research direction.

C. Catheter and Guidewire Segmentation

Catheter and guidewire segmentation is a well-known task in endovascular interventions. In this task, we aim to segment the catheter and guidewire from the background. Unlike catheterization recognition or anticipation, which take a video as input, this segmentation task only uses the X-ray image as input.

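| Baseline | Dice Score | Jaccard Index | mIoU | Accuracy |
|---|---|---|---|---|
| UNet | 51.69 | 57.51 | 31.17 | 63.26 |
| TransUNet | 56.52 | 55.93 | 34.13 | 55.61 |
| SwinUNet | 61.26 | 59.54 | 39.53 | 76.60 |
| SSL | 56.95 | 56.87 | 40.87 | 72.24 |
| SegViT | 63.47 | 54.12 | 42.48 | 68.73 |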

Table 3: Segmentation results on the CathAction dataset.

Network and Training. We benchmark UNet, TransUNet, SwinUNet, SSL, and SegViT. We follow the default training and testing configurations provided in the published papers. We use the Dice Score, Jaccard Index, mIoU, and Accuracy as the evaluation metrics in the segmentation task.
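For reference, the Dice Score and Jaccard Index of a single binary mask can be computed as in the minimal sketch below. It uses the standard definitions and is not taken from the benchmark code; the function name `dice_and_jaccard` and the nested-list mask format are assumptions made to keep the example dependency-free.

```python
def dice_and_jaccard(pred, target):
    """Dice score and Jaccard index for two equally sized binary masks.

    Masks are 2D lists of 0/1 values. Two empty masks count as a
    perfect match (a convention chosen for this sketch).
    """
    inter = pred_sum = target_sum = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            inter += p & t
            pred_sum += p
            target_sum += t
    union = pred_sum + target_sum - inter
    if union == 0:
        return 1.0, 1.0
    dice = 2 * inter / (pred_sum + target_sum)
    jaccard = inter / union
    return dice, jaccard


# A thin, one-pixel-wide structure: a one-pixel shift already halves the Dice score.
pred = [[0, 1, 1],
        [0, 0, 0]]
target = [[0, 1, 0],
          [0, 1, 0]]
dice, jaccard = dice_and_jaccard(pred, target)
```

The toy masks illustrate why thin instruments are hard to score well on: each mask has only two foreground pixels, so a single misplaced pixel removes half of the overlap.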

Results. Table 3 shows the catheter and guidewire segmentation results. Transformer-based networks such as TransUNet and SegViT achieve a higher Dice Score and mIoU than the traditional UNet. SegViT, which uses a vision transformer backbone, achieves the best Dice Score and mIoU, although its margin over the other methods is small.

Discussion. In contrast to traditional segmentation tasks in computer vision, which typically involve objects occupying substantial portions of an image, the segmentation of catheters and guidewires presents a considerably greater challenge. These elongated instruments have extremely slender bodies, making their spatial presence in the image less pronounced. Additionally, the unique characteristics of X-ray images can lead to misidentification of catheters or guidewires as blood vessels. Addressing these challenges in future research is imperative to enhance the accuracy of segmentation outcomes.

D. Collision Detection

Detecting collisions between the tip of the catheter or guidewire and the blood vessel wall is an important task in endovascular intervention. We define the collision detection task as an object detection problem. In particular, the tip of the catheter or guidewire is annotated with a bounding box in every frame of our dataset. Each bounding box is assigned one of two classes: collision when the tip collides with the blood vessel, or normal when there is no collision.

Network and Training. We use YOWO, YOWO-Plus, STEP, and HIT. Since the bounding boxes in our ground truth have relatively small sizes, we also explore tiny object detection methods such as Yolov and EFF. The training process starts with a learning rate of 0.0003, which is then decreased by a factor of 10 after 20 epochs, concluding at 80 epochs. We train all methods with a mini-batch size of 4 on an Nvidia A100 GPU. The average precision (AP) and mean average precision (mAP) are used to evaluate the detection results.
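The sketch below shows the two building blocks behind an AP score: the IoU between a predicted and a ground-truth box, and the average precision of a ranked list of detections. It computes the non-interpolated form of AP, which is an assumption; the benchmark may use an interpolated variant and a specific IoU threshold that the post does not state. The function names and the `(x1, y1, x2, y2)` box format are hypothetical.

```python
def box_iou(a, b):
    """IoU of two axis-aligned boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def average_precision(detections, num_gt):
    """Non-interpolated AP for one class.

    detections: list of (score, is_true_positive) pairs, where a detection
    is a true positive if it matched a ground-truth box of that class.
    num_gt: number of ground-truth boxes of that class.
    """
    ranked = sorted(detections, key=lambda d: d[0], reverse=True)
    tp = fp = 0
    precision_sum = 0.0
    for _, is_tp in ranked:
        if is_tp:
            tp += 1
            precision_sum += tp / (tp + fp)  # precision at this recall step
        else:
            fp += 1
    return precision_sum / num_gt if num_gt else 0.0


iou = box_iou((0, 0, 10, 10), (5, 0, 15, 10))
ap = average_precision([(0.9, True), (0.8, False), (0.7, True)], num_gt=2)
```

The mean over the collision and normal classes would then be the average of the two per-class AP values.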

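| Baseline | AP (Collision) | AP (Normal) | AP (Mean) | mAP (Collision) | mAP (Normal) | mAP (Mean) |
|---|---|---|---|---|---|---|
| STEP | 7.79 | 11.21 | 10.98 | 6.92 | 11.29 | 9.08 |
| YOWO | 8.32 | 12.18 | 11.73 | 7.46 | 12.28 | 9.92 |
| YOWO-Plus | 8.92 | 12.23 | 11.77 | 7.86 | 12.48 | 10.28 |
| HIT | 9.37 | 12.74 | 12.14 | 8.18 | 12.72 | 10.81 |
| Yolov* | 12.30 | 21.08 | 15.89 | 11.88 | 20.04 | 14.11 |
| EFF* | 13.70 | 22.10 | 16.91 | 12.14 | 20.78 | 14.88 |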

Table 4: Collision detection results on the CathAction dataset. The symbol (*) denotes tiny object detectors.

Figure 2

Figure 2: Qualitative results for the collision detection task. The first two columns visualize the collision results, the third column visualizes no collision cases, and the last column visualizes a failure case where the tip was not detected.

Results. Table 4 shows the collision detection results. Tiny object detectors such as Yolov and EFF achieve higher accuracy than the other object detectors. Nevertheless, the performance of all methods remains relatively low, which highlights the challenges that lie ahead for collision detection in endovascular intervention. Figure 2 shows detection examples where EFF has difficulty detecting collisions between the catheter and the blood vessel.

Discussion. Compared to traditional object detection results on vision datasets, the collision detection results on our dataset are significantly lower, with the top mean AP being only 16.91. The challenges of this task come from two factors. First, the tip of the catheter or guidewire is relatively small in X-ray images. Second, the imbalance between the collision and normal classes makes the problem more difficult. Therefore, there is a need to develop specialized methods to address these difficulties. Future work may rely on attention mechanisms, transformers, or foundation models to develop more effective endovascular collision detectors.

E. Domain Adaptation

Our dataset is sourced from two distinct environments: vascular phantom data and animal data. To assess the capacity for learning from phantom data and applying it to real data, we benchmark endovascular interventions under domain adaptation setups. For each task, we train the model on the phantom data and then test it on real animal data. In practice, animal data closely resembles data captured from humans, and performing tasks on real animal or human data is much more challenging than on phantom data.

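| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| RU-LSTM | CVPR 2019 | 22.93 | 23.91 | 22.57 |
| TempAggRe | ECCV 2020 | 17.16 | 18.41 | 18.23 |
| Trans-SVNet | IJCARS 2022 | 19.06 | 17.67 | 19.58 |
| AFFT | WACV 2023 | 25.67 | 26.29 | 26.33 |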

Table 5: Catheterization anticipation results under domain adaptation setup. All methods are trained on phantom data and tested on animal data.

Anticipation Adaptation. We use the same methods RU-LSTM, TempAggRe, Trans-SVNet, and AFFT for anticipation adaptation experiments. Table 5 shows the results. Compared with the setup in Table 1, we can see that there is a significant accuracy drop. This highlights the challenges of applying baseline methods in practical real-world scenarios, particularly when dealing with unforeseen situations in catheterization procedures.

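| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| TDN-ResNet50 | CVPR 2021 | 24.19 | 23.17 | 24.56 |
| TDN-ResNet101 | CVPR 2021 | 25.62 | 24.52 | 25.68 |
| Video Swin Transformer | CVPR 2022 | 28.79 | 27.98 | 28.12 |
| BEVT | CVPR 2022 | 31.22 | 30.48 | 31.79 |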

Table 6: Catheterization recognition results under domain adaptation setup.

Recognition Adaptation. We repeat the catheterization recognition task under the domain adaptation setup. Table 6 shows the results when all baselines are trained on phantom data and tested on animal data. This table also demonstrates that the domain adaptation setup is very challenging: compared to Table 2 under the normal setting, the accuracy drops by approximately 30%.
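The size of the drop can be checked directly from the accuracy columns of Tables 2 and 6. The short calculation below reads the drop as an absolute difference in percentage points, which is an assumption about how the "approximately 30%" figure is meant; the variable names are illustrative.

```python
# Top-1 accuracy (%) copied from Table 2 (normal setting) and Table 6 (domain adaptation).
in_domain = {
    "TDN-ResNet50": 58.34,
    "TDN-ResNet101": 62.50,
    "Video Swin Transformer": 51.67,
    "BEVT": 49.28,
}
adapted = {
    "TDN-ResNet50": 24.19,
    "TDN-ResNet101": 25.62,
    "Video Swin Transformer": 28.79,
    "BEVT": 31.22,
}

# Absolute drop in percentage points for each baseline, and the mean drop.
drops = {name: round(in_domain[name] - adapted[name], 2) for name in in_domain}
mean_drop = round(sum(drops.values()) / len(drops), 2)
```

The per-model drops range from about 18 points (BEVT) to about 37 points (TDN-ResNet101), with a mean of about 28 points. The ranking also reverses: TDN is strongest in the normal setting, while BEVT is strongest under domain adaptation.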

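| Baseline | AP (Collision) | AP (Normal) | AP (Mean) | mAP (Collision) | mAP (Normal) | mAP (Mean) |
|---|---|---|---|---|---|---|
| STEP | 1.53 | 2.12 | 1.87 | 1.09 | 1.98 | 1.62 |
| YOWO | 2.12 | 4.11 | 3.09 | 1.97 | 3.68 | 2.92 |
| YOWO-Plus | 1.18 | 1.43 | 1.21 | 1.07 | 1.26 | 1.09 |
| HIT | 1.31 | 1.19 | 1.24 | 1.06 | 1.18 | 1.11 |
| Yolov* | 7.31 | 8.92 | 8.09 | 6.28 | 7.49 | 7.21 |
| EFF* | 8.27 | 9.16 | 8.19 | 7.61 | 8.29 | 7.88 |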

Table 7: Collision detection results under domain adaptation setup. All methods are trained on phantom data and tested on animal data. The symbol (*) denotes tiny object detectors.

Collision Detection Adaptation. Table 7 shows the collision detection results under domain adaptation. We can see that under the domain adaptation setup, most object detection methods achieve very low accuracy. Therefore, there is an immediate need to improve existing methods or design new ones that can detect collisions in real time for endovascular catheterization procedures.

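| Baseline | Dice Score | Jaccard Index | mIoU | Accuracy |
|---|---|---|---|---|
| UNet | 26.58 | 31.38 | 12.13 | 46.07 |
| TransUNet | 16.16 | 24.19 | 17.23 | 33.61 |
| SwinUNet | 17.41 | 38.14 | 7.52 | 40.79 |
| SSL | 26.91 | 32.04 | 18.72 | 42.44 |
| SegViT | 30.74 | 32.22 | 11.46 | 50.00 |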

Table 8: Domain adaptation segmentation results.

Segmentation Adaptation. Table 8 shows the catheter and guidewire segmentation results when the networks are trained on phantom data and tested on animal data. Similar to other tasks under the domain adaptation setting, we observe a significant accuracy drop in all methods. SegViT still achieves the best Dice Score and Accuracy among the segmentation methods, which suggests that a vision transformer backbone may be a good solution for this task.

2. Discussion

We introduce CathAction, a large-scale dataset for endovascular intervention tasks, encompassing annotated ground truth for segmentation, action understanding, and collision detection. While CathAction marks a significant advancement in endovascular interventions, it is important to acknowledge certain limitations. First, despite its comprehensiveness, the dataset may not encompass every possible clinical scenario and could potentially lack representation for rare or outlier cases. Second, our work currently benchmarks vision-based methods, which exhibit insufficient accuracy, and persisting challenges in generalizability and adaptability to real-world scenarios are evident. This is highlighted by the results presented in Section 1 for all catheterization anticipation, recognition, segmentation, and collision detection tasks. Thirdly, we mostly utilize metrics from the vision community to evaluate the results. These metrics may not fully reflect the clinical needs, and the continuous refinement of evaluation metrics and exploration of potential interdependencies among tasks would demand further research.

From our intensive experiments, we see several research directions that can benefit from our large-scale dataset:

  1. There is an immediate need to develop more advanced methods for catheterization anticipation, recognition, collision detection, and action understanding, especially under the domain adaptation setup. Future work can explore the potential of graph neural networks, temporal information, and multimodal or transfer learning to improve the accuracy and reliability of the methods.
  2. Currently, we address endovascular intervention tasks independently; future work can combine those tasks and tackle them simultaneously (e.g., the anticipation and collision detection tasks can be jointly trained). This would make the research outputs more useful in clinical practice.
  3. Given that CathAction is a large-scale dataset, it can be used to train a foundation model for endovascular interventions or related medical tasks.

3. Conclusion

We introduce CathAction as a large-scale dataset for endovascular intervention research, offering the largest and most comprehensive benchmark to date. With intensive annotated data, CathAction addresses crucial limitations in existing datasets and helps connect computer vision with healthcare tasks. By providing a standardized dataset with public code and public metrics, CathAction promotes transparency, reproducibility, and the collective exploration of different tasks in the field. Our code and dataset are publicly available to encourage further study.

CathAction - A Benchmark for Endovascular Intervention Understanding (Part 2)

1. The CathAction Dataset

This section introduces the CathAction dataset. Specifically, we describe the data collection process and annotation pipeline. We then present statistics regarding different aspects of our large-scale dataset.

Data Collection

Because endovascular intervention is a medical procedure, acquiring extensive human data is often impractical and time-consuming due to privacy constraints. To address this challenge, we instead collect data from two distinct sources:

  1. Utilizing vascular soft silicone phantoms modeled after the human body.
  2. Employing animal subjects, specifically pigs. The selection of pigs is justified by their vascular anatomy, which is widely acknowledged as highly analogous to that of humans.

Ethics
Since our data collection involves experiments with radiation sources (X-ray radiology fluoroscopic systems) and live animals, all relevant ethical approvals were obtained in advance of the collection process. The operators who collect the data are well-trained, professional endovascular surgeons who wear protective suits as part of their daily practice in the hospital.

Figure 1: The human silicone phantom model (a), and the data collection setup in the operating room (b).

Phantom Setup
To ensure that data is collected from various models, we use five adult human aortic arch phantoms made of soft silicone, manufactured by Elastrat Ltd., Switzerland. To enhance realism in the interaction between surgical tools and tissues, the phantoms are connected to a pulsatile pump to simulate the flow of normal human blood. All phantoms are placed beneath an X-ray imaging system to mimic a patient lying on an angiography table, preparing for an endovascular procedure.

Animal Setup
We use five live pigs as subjects for data collection. The animal setup is identical to that of a human procedure. During the endovascular intervention, professional surgeons use an iodine-based contrast agent to enhance visibility of specific structures or fluids within the body. Iodine contrast agents are radiopaque, meaning they absorb X-rays, resulting in improved visibility of blood vessels, organs, and other structures such as the catheter and guidewire during imaging.

Figure 2

Figure 2: Example data collected with phantom models (top row) and animals (bottom row). Animal data are more challenging with less visible catheters or guidewires.

Data Collection
Ten skilled professional surgeons are tasked with cannulating three arteries, namely the left subclavian (LSA), left common carotid (LCCA), and right common carotid (RCCA), using a commercial catheter and guidewire. Throughout each catheterization process, the operating surgeon activates the X-ray fluoroscopy using a pedal in the operating room. We developed a real-time image grabber to transmit the video feed of the surgical scene to a workstation. The experiments are conducted under two interventional radiology fluoroscopic systems: Innova 4100 IQ GE Healthcare and EMD Technologies Epsilon X-ray Generator. Fig 1 shows the data collection setup with human silicone phantoms, and Fig 2 visualizes the collected data with phantom models and real animals. From these examples, we can see that there is a large domain gap between data collected using phantom models and live animals.

Data Annotation

Actions
Based on advice from expert endovascular surgeons, we define five classes to annotate catheterization actions. These classes fall into three groups: catheter (`advance catheter` and `retract catheter`), guidewire (`advance guidewire` and `retract guidewire`), and one action involving both the catheter and guidewire (`rotate`). Surgeons typically rotate both the catheter and guidewire simultaneously, so we use one rotation class. We utilize a free, open-source video editor to annotate the start and end times of each narrated action. All fluoroscopy videos are processed at a 500 × 500 resolution and 24 frames per second (FPS). To ensure annotation quality, all ground-truth actions are manually checked and modified by an experienced endovascular surgeon.
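Since actions are annotated as start and end times while the models consume frames, a conversion step is needed somewhere in the pipeline. The sketch below shows one plausible way to expand time segments into per-frame labels at 24 FPS. The tuple layout of a segment, the function name `segments_to_frame_labels`, and the use of `None` for unlabeled frames are all assumptions; the released annotation format may differ.

```python
FPS = 24
ACTIONS = (
    "advance catheter",
    "retract catheter",
    "advance guidewire",
    "retract guidewire",
    "rotate",
)


def segments_to_frame_labels(segments, num_frames, fps=FPS, background=None):
    """Expand (start_s, end_s, action) segments into one label per frame.

    Frames not covered by any segment keep the `background` value.
    """
    labels = [background] * num_frames
    for start_s, end_s, action in segments:
        if action not in ACTIONS:
            raise ValueError(f"unknown action: {action}")
        first = int(round(start_s * fps))
        last = min(int(round(end_s * fps)), num_frames)  # end is exclusive
        for i in range(first, last):
            labels[i] = action
    return labels


# A two-second clip with one second of advancing and half a second of rotation.
segments = [(0.0, 1.0, "advance guidewire"), (1.5, 2.0, "rotate")]
frame_labels = segments_to_frame_labels(segments, num_frames=48)
```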

Collision Annotation
In practice, the collision between the catheter (or guidewire) and the blood vessel wall mainly occurs at the instrument's tip. Therefore, for each frame of the fluoroscopy video, we annotate the catheter (or guidewire) tip with a bounding box. There are two classes for the bounding boxes: `collision` (when the instrument collides with the blood vessel) and `normal` (when there is no collision). We used an open-source labeling tool to annotate bounding boxes in each video, with all videos encoded at 24 FPS to ensure dataset coherence.
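One way to picture a single collision annotation is as a small record holding the frame index, the tip bounding box, and one of the two class labels. The sketch below is illustrative only: the `TipAnnotation` type, its field names, and the pixel-coordinate box layout are assumptions and do not describe the released file format.

```python
from collections import Counter
from dataclasses import dataclass

LABELS = ("collision", "normal")


@dataclass(frozen=True)
class TipAnnotation:
    """One annotated instrument tip in one frame (hypothetical layout)."""
    frame_index: int
    box: tuple  # (x1, y1, x2, y2) in pixels, assumed convention
    label: str  # "collision" or "normal"

    def __post_init__(self):
        if self.label not in LABELS:
            raise ValueError(f"label must be one of {LABELS}")


def class_counts(annotations):
    """Count boxes per class, e.g. to inspect the collision/normal balance."""
    return Counter(a.label for a in annotations)


annotations = [
    TipAnnotation(0, (240, 250, 256, 266), "normal"),
    TipAnnotation(1, (242, 251, 258, 267), "normal"),
    TipAnnotation(2, (245, 253, 261, 269), "collision"),
]
counts = class_counts(annotations)
```

Counting labels this way is a quick check on how many collision boxes a video contains relative to normal boxes.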

Segmentation
The combination of guidewire and catheter is common in endovascular interventions, where precise navigation through blood vessels is essential for procedure success. Unlike most previous datasets that consider both catheter and guidewire as one class, we manually label catheter and guidewire classes separately in our dataset. Our segmentation ground truth thus provides a more detailed understanding of endovascular interventions.

Dataset Statistics

Overview
As summarized in Table 1 of Part 1, CathAction is a large-scale benchmark for endovascular interventions. Our dataset consists of approximately 500,000 annotated frames for action understanding and collision detection, and around 25,000 ground-truth masks for catheter and guidewire segmentation. There are a total of 569 videos in our dataset. Some collected video samples are illustrated in Fig 2. We believe CathAction is currently the largest, most challenging, and most comprehensive dataset of endovascular interventions.

Statistics
The CathAction dataset is annotated with a primary focus on catheters and guidewires. Fig. 3 provides an overview of the distribution of action classes in both animal and phantom data, while Fig. 4 portrays the distribution of action segment lengths, illustrating the substantial variability in segment duration. Additionally, Fig. 5 visually compares the number of bounding boxes between phantom and animal data, revealing a significant disparity between counts of normal and collision boxes, as expected due to the infrequency of collisions in real-world scenarios.

Figure 3

Figure 3: Distribution of the number of action classes in the CathAction dataset. Left-side: Distribution on real animal data. Right-side: Distribution on phantom data.

Figure 4

Figure 4: Duration distribution of segments' actions in the CathAction dataset, on real animal data and phantom data.

Figure 5

Figure 5: Comparison of the number of bounding box objects in real animal data and phantom data.

Adaptation Property
Since data is collected from two sources—phantoms and real animals—a domain gap exists between the two data types. Fig. 2 and Fig. 5 also demonstrate the adaptation property shared between phantom and animal data. This distinctive domain gap renders CathAction a formidable benchmark for evaluating domain adaptation, a critical problem in medical domains where collecting real human data is often infeasible. Using CathAction, we can develop domain adaptation techniques, learning from synthetic or phantom data and effectively applying that knowledge to genuine animal or human data, bridging the gap between controlled simulation and real-world scenarios.

Next

In the next post, we will benchmark our new dataset CathAction on various tasks.

CathAction - A Benchmark for Endovascular Intervention Understanding (Part 1)

Real-time visual feedback from catheterization analysis is crucial for enhancing surgical safety and efficiency during endovascular interventions. However, existing datasets are often limited to specific tasks, are small in scale, and lack the comprehensive annotations necessary for broader endovascular intervention understanding. To tackle these limitations, we introduce CathAction, a large-scale dataset for catheterization understanding. Our CathAction dataset encompasses approximately 500,000 annotated frames for catheterization action understanding and collision detection, and 25,000 ground truth masks for catheter and guidewire segmentation. For each task, we benchmark recent related works in the field. We further discuss the challenges of endovascular interventions compared to traditional computer vision tasks and point out open research questions. We hope that CathAction will facilitate the development of endovascular intervention understanding methods that can be applied to real-world applications.

1. Introduction

| Dataset | Collection | Type | #Frames | Source | Annotation | Public | Task |
|---|---|---|---|---|---|---|---|
| Barbu et al. | X-ray | Video | 535 | Real | Manual | No | Segmentation |
| Wu et al. | 3D Echo | Video | 800 | Real | Manual | No | Segmentation |
| Ambrosini et al. | X-ray | Image | 948 | Real | Manual | No | Segmentation |
| Mastmeyer et al. | 3D MRI | Image | 101 | Real | Manual | No | Segmentation |
| Yi et al. | X-ray | Image | 2,540 | Synthesis | Automatic | No | Segmentation |
| Nguyen et al. | X-ray | Image | 25,271 | Phantom | Semi-Auto | No | Segmentation |
| Danilov et al. | 3D Ultrasound | Video | 225 | Synthetic | Manual | No | Segmentation |
| Delmas et al. | X-ray | Image | 2,357 | Simulated | Automatic | No | Reconstruction |
| Brost et al. | X-ray | Image | 938 | Clinical | Semi-Auto | No | Tracking |
| Ma et al. | X-ray, CT | Image | 1,048 | Clinical | Manual | No | Reconstruction |
| CathAction (ours) | X-ray | Video | 500,000+ | Phantom & Animal | Manual | Yes | Segmentation, Action Understanding, Collision Detection |

Table 1: Endovascular intervention datasets comparison.

Cardiovascular diseases are one of the leading causes of death worldwide. Endovascular intervention has become the gold standard treatment for these diseases, preferred for its advantages over traditional open surgery, including smaller incisions, reduced trauma, and lower risks of comorbidities for patients. Endovascular interventions involve maneuvering small and long medical instruments, such as catheters and guidewires, within the vasculature through small incisions to reach targeted areas for treatment delivery, such as artery stenting, tissue ablation, and drug delivery. However, such tasks require high technical skill, with the primary challenge being to avoid collisions with the vessel wall, which could result in severe consequences, including perforation, hemorrhage, and organ failure. In practice, surgeons rely on 2D X-ray fluoroscopy images to perform these tasks within the 3D human body, which adds a significant challenge in safely controlling the catheter and guidewire.

Recently, learning-based methods for computer-assisted intervention systems have emerged for diverse tasks. Numerous methodologies have been developed to address the challenges of endovascular interventions, including catheter and guidewire segmentation, vision-based force sensing, learning from demonstration, and skill training assistance. Additionally, various deep learning approaches have been proposed for specific tasks in endovascular interventions, such as instrument motion recognition in X-ray sequences, interventionalist hand motion recognition, and collision detection. However, due to challenges in acquiring medical data, most of these methods rely on synthetic data or small, private datasets. Consequently, despite the critical nature of interventions, current methods have not fully capitalized on recent advancements in deep learning, which typically require large-scale training data.

Over the years, several datasets for endovascular intervention have been introduced. Table 1 shows a detailed comparison between current endovascular intervention datasets. However, these datasets share common limitations. First, they are relatively small in terms of the number of images, as collecting real-world medical data is costly. Second, due to privacy challenges in the medical domain, most existing datasets are kept private. Finally, these datasets are often created for a single task, such as segmentation, and do not support other important tasks in endovascular interventions, such as collision detection or action understanding.

To address these issues, we present CathAction, a large-scale dataset encompassing several endovascular intervention tasks, including segmentation, collision detection, and action understanding. To our knowledge, CathAction represents the largest and most realistic dataset specifically tailored for catheter and guidewire tasks.

In summary, we make the following contributions:

  • We introduce CathAction, a large-scale dataset for endovascular interventions, providing manually labeled ground truth for segmentation, action understanding, and collision detection.
  • We benchmark key tasks in endovascular interventions, including catheterization anticipation, recognition, segmentation, and collision detection.
  • We discuss the challenges and open questions in endovascular intervention. Our code and dataset are publicly available.

2. Related Work

Endovascular Intervention Dataset
Several endovascular intervention datasets have been introduced. Barbu et al. proposed a dataset that effectively localizes the entire guidewire and validated it using a traditional threshold-based method. Other datasets consider fluoroscopy videos at the image level, with mask annotations for each frame from the fluoroscopy videos. For instance, Ambrosini et al. developed a dataset with 948 annotated mask segmentation instances considering both catheter and guidewire as one class. Similarly, Mastmeyer et al. collected and annotated a dataset with 101 segmentation masks for the real catheter from 3D MRI data. More recently, Nguyen et al. proposed a dataset that considers both catheter and guidewire as one class. Overall, most of these datasets have limitations in terms of size, task categories, and focus. To overcome these limitations, we introduce CathAction, a large-scale dataset with various tasks, including catheter and guidewire segmentation, collision detection, and catheter action recognition and anticipation. The CathAction dataset enables the development of more accurate and reliable deep learning methods for endovascular interventions.

Catheterization Action Understanding
Deep learning techniques have demonstrated notable achievements in endovascular intervention action understanding. Jochem et al. presented one of the first works utilizing deep learning for catheter and guidewire activity recognition in fluoroscopy sequences. Subsequently, deep learning-based approaches have gained prominence as the most widely utilized solution for interventionalist hand motion recognition. For instance, Akinyemi et al. introduced a deep learning model based on convolutional neural networks (CNNs) that incorporates convolutional layers for automatic feature extraction and identifies operators' actions. Additionally, Wang et al. proposed a multimodal fusion architecture for recognizing eight common operating behaviors of interventionists. Despite extensive research on deep learning methods for endovascular intervention, progress is constrained by the scarcity of medical data: most of these methods use synthetic data or small, private datasets. As a result, although intervention is a crucial procedure, it has not fully benefited from recent deep learning advancements, which usually require large-scale training data.

Catheter and Guidewire Segmentation
Catheter and guidewire segmentation is crucial for real-time endovascular interventions. Many methods have been proposed to address the challenges of catheter and guidewire segmentation. The outcomes of catheter and guidewire segmentation can be applied in vision-based force sensing, learning from demonstration, or skill training assistance applications. Traditional methods for catheterization segmentation adopt thresholding-based techniques and do not generalize well on X-ray data. Deep learning methods can learn meaningful features from input data, but they are challenging to apply to catheter segmentation due to the lack of real X-ray data and the tediousness of manual ground truth labeling. Many current learning-based techniques for catheter segmentation and tracking are limited to training on small-scale datasets or synthetic data due to the challenges of large-scale data collection. Our dataset provides manual ground truth labels for both the catheter and guidewire, offering substantial development for catheter and guidewire segmentation.

Collision Detection
Collision detection is a crucial task in endovascular interventions to ensure patient safety. Several attempts have been made to incorporate deep learning models into collision detection, but these methods have focused on identifying risky actions in simulated datasets. While existing methods can be useful for identifying potential hazards, they cannot localize the position of collisions or provide visual feedback. Additionally, these methods have not been widely used in real-world settings due to the lack of annotated bounding boxes for collisions of guidewire tips with vessel walls. Our dataset addresses this limitation by providing annotated bounding boxes for collision events in both phantom and real-world data. This enables the development of deep learning models that can detect collisions in real-time and provide visual or haptic feedback to surgeons.

Next

In the next post, we will describe our new dataset CathAction.

Lightweight Language-driven Grasp Detection using Conditional Consistency Model (Part 3)

Grasping Machine

1. Experiments

Experiment Setup

Dataset. We use the Grasp-Anything dataset in our experiment. Grasp-Anything is a large-scale dataset for language-driven grasp detection with 1M samples. Each image in the dataset is accompanied by one or several prompts describing a general object grasping action or grasping an object at a specific location.

Evaluation Metrics. Our primary evaluation metric is the success rate, defined as in previous works: a predicted grasp counts as successful if its IoU with the ground truth grasp exceeds 25% and its offset angle is less than $30^\circ$. We also use the harmonic mean ('H') of the seen and unseen success rates to measure the overall success rate. The latency (inference time) of all methods is reported in seconds, measured on the same NVIDIA A100 GPU.
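To make the metrics concrete, the following is a minimal Python sketch of the success criterion and the harmonic mean. The function names are illustrative and not from the original work, the IoU is taken as a precomputed input, and treating grasp angles as equivalent modulo 180 degrees is an assumption about the rectangle grasp representation.

```python
def harmonic_mean(seen: float, unseen: float) -> float:
    """Harmonic mean H of the seen and unseen success rates."""
    if seen + unseen == 0:
        return 0.0
    return 2.0 * seen * unseen / (seen + unseen)


def angle_offset_deg(pred_deg: float, gt_deg: float) -> float:
    """Smallest angle between two grasp orientations, in degrees.

    Assumption: a grasp rectangle rotated by 180 degrees is the same grasp,
    so the offset is taken modulo 180.
    """
    diff = abs(pred_deg - gt_deg) % 180.0
    return min(diff, 180.0 - diff)


def is_successful(iou: float, pred_deg: float, gt_deg: float,
                  iou_thresh: float = 0.25, angle_thresh: float = 30.0) -> bool:
    """Success if IoU exceeds 25% and the angle offset is below 30 degrees."""
    return iou > iou_thresh and angle_offset_deg(pred_deg, gt_deg) < angle_thresh


# Seen = 0.37 and unseen = 0.18 give H of roughly 0.24.
print(round(harmonic_mean(0.37, 0.18), 2))
```

For example, seen and unseen success rates of 0.37 and 0.18 yield a harmonic mean of about 0.24, which penalizes a large gap between the two more than an arithmetic mean would.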

Comparison with Grasp Detection Methods

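| Baseline | Seen | Unseen | H | Latency (s) |
|---|---|---|---|---|
| GR-ConvNet | 0.37 | 0.18 | 0.24 | 0.022 |
| Det-Seg-Refine | 0.30 | 0.15 | 0.20 | 0.200 |
| GG-CNN | 0.12 | 0.08 | 0.10 | 0.040 |
| CLIPORT | 0.36 | 0.26 | 0.29 | 0.131 |
| CLIP-Fusion | 0.40 | 0.29 | 0.33 | 0.157 |
| MaskGrasp | 0.50 | 0.46 | 0.45 | 0.116 |
| LLGD (ours) with 1 timestep | 0.47 | 0.34 | 0.40 | 0.035 |
| LLGD (ours) with 3 timesteps | 0.52 | 0.38 | 0.45 | 0.106 |
| LLGD (ours) with 10 timesteps | 0.53 | 0.39 | 0.46 | 0.264 |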

Table 1: Comparison with Traditional Grasp Detection Methods.

We compare our LLGD with GR-ConvNet, Det-Seg-Refine, GG-CNN, CLIPORT, MaskGrasp, and CLIP-Fusion on the Grasp-Anything dataset. Table 1 shows that our proposed LLGD outperforms traditional grasp detection methods by a clear margin. Our inference time is also competitive with other methods.

Comparison with Lightweight Diffusion Models

| Method | Seen | Unseen | H | Latency (s) |
|---|---|---|---|---|
| LGD with 3 timesteps | 0.42 | 0.29 | 0.35 | 0.074 |
| LGD with 30 timesteps | 0.49 | 0.41 | 0.45 | 0.741 |
| LGD with 1000 timesteps | 0.52 | 0.42 | 0.47 | 26.12 |
| SnapFusion with 500 timesteps | 0.49 | 0.37 | 0.43 | 12.95 |
| LightGrad with 250 timesteps | 0.51 | 0.34 | 0.43 | 6.420 |
| LLGD (ours) with 1 timestep | 0.47 | 0.34 | 0.40 | 0.035 |
| LLGD (ours) with 3 timesteps | 0.52 | 0.38 | 0.45 | 0.106 |
| LLGD (ours) with 10 timesteps | 0.53 | 0.39 | 0.46 | 0.264 |

Table 2: Comparison with Diffusion Models for Language-Driven Grasp Detection.

In this experiment, we compare our LLGD with other diffusion models for language-driven grasp detection. In particular, we compare with LGD using DDPM, and recent lightweight diffusion works: SnapFusion with 500 timesteps and LightGrad with 250 timesteps.

Table 2 shows the results of diffusion models for language-driven grasp detection. We can see that the accuracy and inference time of the classical diffusion model LGD strongly depend on the number of denoising timesteps. LGD with 1000 timesteps achieves reasonable accuracy but has a very long latency. Lightweight diffusion models such as SnapFusion and LightGrad show reasonable results and inference speed. However, our method achieves the highest accuracy with the fastest inference speed.

Conditional Consistency Model Demonstration

Figure 1

Figure 1: Consistency model analysis. With text prompt input "Grasp the cup at its handle", we compare the trajectory grasp pose of our method and LGD. In the figure, the top row illustrates the trajectory of LGD, while the bottom row corresponds to the trajectory of our LLGD.

In this analysis, we verify the effectiveness of our conditional consistency model. In Figure 1, we visualize the grasp pose with respect to the time index $t$. In the LGD model, as the discrete diffusion model is employed with $T=1000$, we have to perform the diffusion steps with a step size of 1, which results in a very slow inference speed. Moreover, the grasp pose trajectory still exhibits significant fluctuations. Our method can arbitrarily select boundary time points for the continuous consistency model. It is evident that the number of iterations required by our method is significantly less than that of LGD for the same value of $T$, which contributes to the "lightweight" factor. Furthermore, the grasp pose of our method at $t=603$ has almost converged to the ground truth, while LGD using DDPM at $t=350$ has not yet achieved a successful grasp.

Ablation Study

Figure 2

Figure 2: Visualization of detection results of different language-driven grasp detection methods.

Visualization. Figure 2 shows qualitative results of our method and other baselines. The outcomes suggest that our method LLGD generates more semantically plausible grasp poses given the same text query than other baselines. In particular, other methods usually show grasp poses at a location not well-aligned with the text query, while our method shows more suitable detection results.

Figure 3

Figure 3: In the wild detection results. Images are from the internet.

In the Wild Detection. Figure 3 illustrates the outcomes of applying our method to random images from the internet. The results demonstrate that our LLGD can effectively detect the grasp pose given the language instructions on real-world images. Our method showcases a promising zero-shot learning ability, as it successfully interprets grasp actions on images it has never encountered during training.

Figure 4

Figure 4: Prediction failure cases.

Failure Cases. Although promising results have been achieved, our method still predicts incorrect grasp poses in some cases. The wide variety of objects and grasping prompts poses a challenging problem, as the network cannot capture all the diverse circumstances that arise in real life. Figure 4 depicts some failure cases where LLGD incorrectly predicts the results, which can be attributed to multiple similar objects that are difficult to distinguish and text prompts that lack detailed descriptions for accurate result determination.

Robotic Experiments

Robotic Setup. Our lightweight language-driven grasp detection pipeline is incorporated within a robotic grasping framework that employs a KUKA LBR iiwa R820 robot to deliver quantifiable outcomes. Utilization of the RealSense D435i camera enables the translation of grasping information from LLGD into a 6DoF grasp posture, bearing resemblance to previous works. Subsequently, a trajectory optimization planner is used to execute the grasping action. Experiments were conducted on a table surface for the single object scenario and the cluttered scene scenario, wherein various objects were placed to test each setup. Table 3 shows the success rate of our method and other baseline models.

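| Baseline | Single | Cluttered |
|---|---|---|
| GR-ConvNet + CLIP | 0.33 | 0.30 |
| Det-Seg-Refine + CLIP | 0.30 | 0.23 |
| GG-CNN + CLIP | 0.10 | 0.07 |
| CLIPORT | 0.27 | 0.30 |
| CLIP-Fusion | 0.40 | 0.40 |
| SnapFusion | 0.40 | 0.39 |
| LightGrad | 0.41 | 0.40 |
| LLGD (ours) | 0.43 | 0.42 |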

Table 3: Robotic language-driven grasp detection results.

Our method outperforms other baselines in both single object and cluttered scenarios. Furthermore, our lightweight model allows rapid execution speed without sacrificing the accuracy of visual grasp detection.

2. Discussion

Limitation. Despite achieving notable results in real-time applications, our method still has limitations and predicts incorrect grasp poses in challenging real-world images. Faulty grasp poses are often due to the text and the attention map of the visual features not being well-aligned, as shown in Figure 4. From our experiments, we see that when grasp instruction sentences contain challenging nouns that are rare in the dataset, ambiguity arises in parsing the text prompts, which is usually the main cause of incorrect grasp pose predictions. Therefore, providing instruction prompts with clear meanings is essential for the robot to understand and execute the correct grasping action.

Future work. We see several prospects for improvement in future work: 1. Expanding our method to handle 3D space is essential, implementing it for 3D point clouds and RGB-D images to avoid the lack of depth information in robotic applications. 2. Addressing the gap between the semantic concept of text prompts and input images, analyzing the detailed geometry of objects to effectively distinguish between items with similar structures. 3. Expanding the problem to more complex language-driven manipulation applications. For instance, if the robot wants to grasp a plate containing apples, it would need to manipulate the objects in such a manner that prevents the apples from falling.

Lightweight Language-driven Grasp Detection using Conditional Consistency Model (Part 2)

Grasping Machine

1. Lightweight Language-driven Grasp Detection

Overview

Given an input RGB image and a text prompt describing the object of interest, we aim to detect the grasping pose on the image that best matches the text prompt input. We follow the popular rectangle grasp convention widely used in previous work to define the grasp.

In the diffusion model, we represent the target grasp pose as $\mathbf{x}_0$. The objective of our diffusion process for language-driven grasp detection is to denoise from a noisy state $\mathbf{x}_T$ to the original grasp pose $\mathbf{x}_0$, conditioned on the input image and grasp instruction represented by $y$. The forward process in a traditional conditional diffusion model is defined as:

$$q(\mathbf{x}_t|\mathbf{x}_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I})~, \tag{1}$$

where the hyperparameter βₜ is the amount of noise added at diffusion step t ∈ [0,T] ⊆ ℝ.
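As an illustration of Equation (1), a single forward noising step can be sketched as follows. The 5-D pose layout and the value of `beta_t` are assumptions chosen for the example, not settings from the original work.

```python
import math
import random


def forward_step(x_prev, beta_t, rng):
    """Draw x_t ~ N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I), as in Equation (1)."""
    scale = math.sqrt(1.0 - beta_t)
    std = math.sqrt(beta_t)
    return [scale * v + std * rng.gauss(0.0, 1.0) for v in x_prev]


rng = random.Random(0)
# Hypothetical normalized grasp pose: (center x, center y, width, height, angle).
x = [0.5, 0.5, 0.2, 0.1, 0.3]
for _ in range(10):
    x = forward_step(x, beta_t=0.02, rng=rng)
print(x)  # the pose after ten small noising steps
```

With `beta_t = 0` the step returns its input unchanged, and larger values of `beta_t` move the pose toward pure Gaussian noise more quickly.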

To train a diffusion model with condition $y$, we use a neural network to learn the reverse process:

$$p_\phi(\mathbf{x}_{t-1}|\mathbf{x}_t,y) = \mathcal{N}(\mu_\phi(\mathbf{x}_t,t,y),\Sigma_\phi(\mathbf{x}_t,t,y))~. \tag{2}$$

In our approach, we utilize the diffusion process in the continuous domain, where $\mathbf{x}_t$ is the grasp pose state at arbitrary time index $t$. Unlike popular discrete diffusion models, by using a continuous space, we can improve sample quality and reduce inference times due to the ability to traverse the diffusion process at arbitrary timesteps, allowing for more fine-grained control over the denoising process.

Method Overview

Figure 1: The overview of our method. First, the input RGB image and text prompt are fed into the feature encoder and ALBEF fusion. Subsequently, we concurrently train two models with the same architectures: A score network to estimate the probability flow Ordinary Differential Equation (ODE) trajectory for the diffusion process and a conditional consistency model to determine the grasp pose with a few denoising steps.

Conditional Consistency Model for LLGD

To reduce the inference time during the denoising step of the diffusion model, we aim to estimate the original grasp pose with just a few denoising steps. Since our language-driven grasp detection task has the condition $y$, we introduce a conditional consistency model based on the consistency concept to directly infer the original grasp pose during inference:

$$\mathbf{f}_\theta(\mathbf{x}_t,t,y) = \begin{cases} \mathbf{x}_t & t \in [0,\epsilon] \\ \mathbf{F}_\theta(\mathbf{x}_t,t,y) & t \in (\epsilon,T] \end{cases}~, \tag{3}$$

where $\mathbf{f}_\theta(\mathbf{x}_\epsilon, \epsilon, y) = \mathbf{x}_\epsilon$ is the boundary condition, and $\mathbf{F}_\theta(\mathbf{x}_t,t,y)$ is a free-form deep neural network whose output has the same dimensionality as $\mathbf{x}_t$.
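The piecewise parameterization in Equation (3) can be sketched in a few lines. Here `toy_network` is a hypothetical stand-in for the free-form network $\mathbf{F}_\theta$, and the default `eps=1.0` mirrors the value of $\epsilon$ reported in the training settings.

```python
def consistency_fn(F_theta, x_t, t, y, eps=1.0):
    """Equation (3): identity for t in [0, eps], network output for t in (eps, T]."""
    if t <= eps:
        return list(x_t)
    return F_theta(x_t, t, y)


def toy_network(x_t, t, y):
    """Hypothetical placeholder for F_theta; always predicts the zero pose."""
    return [0.0 for _ in x_t]


print(consistency_fn(toy_network, [1.0, 2.0], 1.0, None))  # boundary: input returned
print(consistency_fn(toy_network, [1.0, 2.0], 5.0, None))  # network output
```

Enforcing the boundary case outside the network guarantees the identity at $t \le \epsilon$ by construction, so it does not have to be learned.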

To train our conditional consistency model, we employ knowledge distillation from a continuous diffusion process:

$$d\mathbf{x}_{t} = -\frac{1}{2}\gamma_t\mathbf{x}_t dt + \sqrt{\gamma_t} d\mathbf{w}_t~, \tag{4}$$

where $\gamma_t$ is a non-negative function referred to as the noise schedule, and $\mathbf{w}_t$ is the standard Brownian motion. This forward process creates a trajectory of grasp poses $\{\mathbf{x}_t\}_{t=0}^T$. The grasp pose state $\mathbf{x}_t$ depends on the time index $t$ and the input image and text prompt. The grasp distribution $p(\mathbf{x}_0|y)$ from the dataset is transformed into $p(\mathbf{x}_T|y) \sim \mathcal{N}(0, \mathbf{I})$. Given the ground truth grasp pose $\mathbf{x}_0$, we can sample $\mathbf{x}_t$ at arbitrary $t$:

$$p(\mathbf{x}_t|\mathbf{x}_0) = \mathcal{N}(\mu_t, \Sigma_t)~, \tag{5}$$

where

$$\mu_t = e^{\frac{1}{2}\rho_t} \mathbf{x}_0, \quad \Sigma_t = (1 - e^{\rho_t})\mathbf{I}, \quad \rho_t = -\int_{0}^{t} \gamma_s ds~.$$
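The closed form in Equation (5) makes it possible to sample $\mathbf{x}_t$ directly from $\mathbf{x}_0$. The sketch below assumes a constant noise schedule $\gamma_s = \gamma$, so that $\rho_t = -\gamma t$; both the constant schedule and the value `gamma=0.005` are illustrative assumptions, since the actual schedule is not specified here.

```python
import math
import random


def rho(t, gamma=0.005):
    """rho_t = -integral_0^t gamma_s ds for a constant schedule (assumption)."""
    return -gamma * t


def perturb(x0, t, rng, gamma=0.005):
    """Sample x_t ~ N(exp(rho_t / 2) * x0, (1 - exp(rho_t)) * I), as in Equation (5)."""
    r = rho(t, gamma)
    mean_scale = math.exp(0.5 * r)
    std = math.sqrt(1.0 - math.exp(r))
    return [mean_scale * v + std * rng.gauss(0.0, 1.0) for v in x0]


rng = random.Random(0)
x0 = [0.5, -0.2, 0.1]
print(perturb(x0, 0.0, rng))     # no noise at t = 0
print(perturb(x0, 1000.0, rng))  # close to a standard Gaussian sample
```

Under this assumed schedule, the standard deviation at $t = 1000$ is within one percent of 1, matching the statement that $p(\mathbf{x}_T|y)$ approaches $\mathcal{N}(0, \mathbf{I})$.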

The diffusion process in Equation (4) has a corresponding probability flow ODE. With the conditional variable $y$, it can be written as:

$$\frac{d\mathbf{x}_t}{dt} = -\frac{1}{2}\gamma_t\left[\mathbf{x}_t + \nabla\log p(\mathbf{x}_t|y)\right]~, \tag{6}$$

where $\nabla\log p(\mathbf{x}_t|y)$ is the score function of the conditional diffusion model.

Suppose that we have a neural network $\mathbf{s}_\phi(\mathbf{x}_t, t, y)$ that can approximate the score function $\nabla\log p(\mathbf{x}_t|y)$, i.e., $\mathbf{s}_\phi(\mathbf{x}_t, t, y) \approx \nabla\log p(\mathbf{x}_t|y)$. After training the score network, we can replace the $\nabla\log p(\mathbf{x}_t|y)$ term in Equation (6) with the neural network:

$$\frac{d\mathbf{x}_t}{dt} = -\frac{1}{2}\gamma_t\left[\mathbf{x}_t + \mathbf{s}_\phi(\mathbf{x}_t, t, y)\right]~. \tag{7}$$

Score Function Loss. In order to approximate the score function $\nabla\log p(\mathbf{x}_t|y)$, the conditional denoising estimator minimizes the following objective:

$$\mathcal{L}_{\rm score}=\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~, \tag{8}$$

where $\lambda(t) \in \mathbb{R}^+$ is a positive weighting function.
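Because the perturbation kernel $p(\mathbf{x}_t|\mathbf{x}_0)$ is an isotropic Gaussian $\mathcal{N}(\mu_t, \sigma_t^2\mathbf{I})$, the regression target in Equation (8) has the closed form $\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) = -(\mathbf{x}_t - \mu_t)/\sigma_t^2$. A minimal sketch of one term of the loss follows; the function names and the default weight are illustrative assumptions.

```python
def gaussian_score(x_t, mean, var):
    """Gradient of log N(x_t; mean, var * I) with respect to x_t."""
    return [-(a - m) / var for a, m in zip(x_t, mean)]


def score_loss_term(x_t, mean, var, predicted_score, weight=1.0):
    """One sample of Equation (8): weight * ||target score - predicted score||^2."""
    target = gaussian_score(x_t, mean, var)
    return weight * sum((a - b) ** 2 for a, b in zip(target, predicted_score))


target = gaussian_score([1.0, 0.0], [0.0, 0.0], 0.5)
print(target)  # points from x_t back toward the mean
print(score_loss_term([1.0, 0.0], [0.0, 0.0], 0.5, target))  # zero for a perfect prediction
```

In training, `predicted_score` would come from the score network $\mathbf{s}_\phi(\mathbf{x}_t, t, y)$, and the expectation is approximated by averaging this term over sampled $t$, $(\mathbf{x}_0, y)$, and $\mathbf{x}_t$.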

Proposition 1. Suppose that $\mathbf{x}_t$ is conditionally independent of $y$ given $\mathbf{x}_0$, then minimizing $\mathcal{L}_{\rm score}$ is the same as minimizing:

$$\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_t,y \sim p(\mathbf{x}_t,y) \\ \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~.$$

Proof. Because $\mathbf{x}_t$ is conditionally independent of $y$ given $\mathbf{x}_0$, we have:

$$\begin{aligned} &\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \mathbf{x}_0 \sim p(\mathbf{x}_0|y)\\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \mathbf{x}_0 \sim p(\mathbf{x}_0|y)\\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0,y) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0,y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \end{subarray} }\left[\Phi(t,y)\right]~, \tag{9} \end{aligned}$$

where

$$\begin{aligned} &\Phi(t,y)\\ &=\mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_0 \sim p(\mathbf{x}_0|y)\\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0,y) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0,y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~. \end{aligned}$$

If $y$ and $t$ are fixed, we can define a transition probability that does not depend on these variables, $q(\mathbf{x}_0) = p(\mathbf{x}_0|y)$, $\kappa(\mathbf{x}_t)=\mathbf{s}_\phi(\mathbf{x}_t,t,y)$. According to Vincent P., 2011, we have:

$$\begin{aligned} \Phi(t,y) &= \mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_0 \sim q(\mathbf{x}_0)\\ \mathbf{x}_t \sim q(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log q(\mathbf{x}_t|\mathbf{x}_0) - \kappa(\mathbf{x}_t)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} (\mathbf{x}_0,\mathbf{x}_t) \sim q(\mathbf{x}_0,\mathbf{x}_t)\\ \end{subarray} }\left[\lambda(t) \|\nabla\log q(\mathbf{x}_t|\mathbf{x}_0) - \kappa(\mathbf{x}_t)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_t \sim q(\mathbf{x}_t)\\ \end{subarray} }\left[\lambda(t) \|\nabla\log q(\mathbf{x}_t) - \kappa(\mathbf{x}_t)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_t \sim p(\mathbf{x}_t|y)\\ \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~. \tag{10} \end{aligned}$$

From Equations (9) and (10), we can prove the equivalence of the two objective functions:

$$\begin{aligned} &\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ =& \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|y) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ =& \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ (\mathbf{x}_t,y) \sim p(\mathbf{x}_t,y) \\ \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~. \tag{11} \end{aligned}$$

Discretization. Consider discretizing the time horizon $[\epsilon,T]$ into $N-1$ sub-intervals with boundaries $t_1=\epsilon<t_2<t_3<\ldots<t_{N}=T$. If $N$ is sufficiently large, we can use an ODE solver to estimate the next discretization step:

$$\begin{aligned} \hat{\mathbf{x}}_{t_i} &= \mathbf{x}_{t_{i+1}} + (t_i - t_{i+1}) \left. \frac{d\mathbf{x}}{dt} \right|_{t = t_{i+1}} \\ &= \mathbf{x}_{t_{i+1}} - \frac{1}{2}\gamma_{t_{i+1}} (t_i - t_{i+1})\left[\mathbf{x}_{t_{i+1}} + \mathbf{s}_\phi(\mathbf{x}_{t_{i+1}},t_{i+1},y)\right]~. \tag{12} \end{aligned}$$
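The discretization step is a single Euler update of the ODE in Equation (7), taken backward in time from $t_{i+1}$ to $t_i$ with the derivative evaluated at $t_{i+1}$. A minimal sketch, where the score value is passed in as a plain list rather than computed by a trained network (an assumption for illustration):

```python
def euler_step(x_next, t_i, t_next, gamma_next, score_next):
    """One Euler step from t_{i+1} (t_next) back to t_i.

    x_next     : state at time t_{i+1}
    gamma_next : noise schedule value at t_{i+1}
    score_next : score network output at (x_next, t_{i+1}, y)
    """
    dt = t_i - t_next  # negative, since we move toward smaller t
    return [x - 0.5 * gamma_next * dt * (x + s)
            for x, s in zip(x_next, score_next)]


# For a standard Gaussian the score is -x, so the drift x + s vanishes
# and the state is left unchanged.
print(euler_step([0.3, -1.2], 1.0, 2.0, 0.1, [-0.3, 1.2]))
```

The sanity check in the example reflects the fact that a standard Gaussian is stationary under this ODE, which is a quick way to test an implementation before plugging in a learned score.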

Conditional Consistency Model Loss. To enable fast sampling, we expect the predicted point $\hat{\mathbf{x}}_{t_i}$ and $\mathbf{x}_{t_{i+1}}$ to lie on the same probability flow ODE trajectory. We propose a conditional consistency loss to enforce this constraint:

$$\mathcal{L}_{\rm consistency} = \mathbb{E}_{ \begin{subarray}{l} i \sim \mathcal{U}[1, N - 1] \\ \mathbf{x}_{t_{i+1}} \sim p(\mathbf{x}_{t_{i+1}}|\mathbf{x}_0) \end{subarray} } \left[\lambda(t_i) \|\mathbf{f}_\theta(\mathbf{x}_{t_{i+1}},t_{i+1},y) - \mathbf{f}_{\theta^*}(\hat{\mathbf{x}}_{t_{i}},t_{i},y)\|^2 \right]~, \tag{13}$$

where $\hat{\mathbf{x}}_{t_i}$ is calculated in Equation 12, $\mathbf{x}_{t_{i+1}}$ is sampled from the Gaussian distribution in Equation 5, and $\theta$ denotes the parameters of the neural network $\mathbf{f}$.

Additionally, we need to minimize the discrepancy between the predicted and ground truth grasp poses with the detection loss:

$$\mathcal{L}_{\rm detection} = \mathbb{E}_{ \begin{subarray}{l} i \sim \mathcal{U}[1, N] \\ \mathbf{x}_{t_{i}} \sim \mathcal{N}(\mu_{t_{i}},\Sigma_{t_{i}}) \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \end{subarray} }\left[\lambda(t_i)\|\mathbf{f}_\theta(\mathbf{x}_{t_i}, t_i, y) - \mathbf{x}_0\|^2\right]~. \tag{14}$$

The overall training objective for our method is:

$$\mathcal{L}_{\rm total} = \mathcal{L}_{\rm consistency} + \mathcal{L}_{\rm detection}~. \tag{15}$$

Network Details

The input of our network is the image and a corresponding grasping text prompt represented as $e$ (for example, "grasp the fork at its handle"). We first extract the image feature using a 12-layer vision transformer (ViT) image encoder. The input text prompt is encoded by a text encoder using BERT or CLIP. We then combine and learn the features of the input text prompt and input image using the ALBEF fusion network. The fused features are fed into a score network and our conditional consistency model, which is used to learn the grasp pose. Figure 1 shows the details of our network.

Score Network. In practice, we utilize a score network composed of several MLP layers to extract features from three components: the noisy grasp pose $\mathbf{x}_t$, the time index $t$, and the conditional vision-language embedding $y$. Subsequently, these features are concatenated, and the score function is extracted through a final MLP layer. It is crucial to ensure that the output dimension of the score network is identical to the dimension of the input $\mathbf{x}_t$ because, fundamentally, the score function is the gradient of the log-density of the grasp pose distribution given the condition $y$. Our conditional consistency model's network has an architecture similar to the score network; however, its output is the predicted grasp pose.

Algorithm 1: Inference Process

Input: Image and text prompt, conditional consistency model $\mathbf{f}_\theta(\mathbf{x},t,y)$, number of inference steps $P$, sequence of time points $t_1 = \epsilon < t_2 < t_3 < \dots < t_{P} = T$, noise scheduler $\alpha_t = e^{\rho_t}$.

$y \gets \text{ALBEF (image, prompt)}$

Initial grasp noise $\mathbf{x}_T \sim \mathcal{N}(0,\mathbf{I})$

$\mathbf{x}_0 \gets \mathbf{f}_\theta(\mathbf{x}_T,T,y)$

For $i = P - 1$ to $2$:

  • Sample $\mathbf{z} \sim \mathcal{N}(0,\mathbf{I})$
  • $\mathbf{x}_{t_i} \gets \sqrt{\alpha_{t_i}}\mathbf{x}_0 + \sqrt{1 - \alpha_{t_i}}\mathbf{z}$
  • $\mathbf{x}_0 \gets \mathbf{f}_\theta(\mathbf{x}_{t_i},t_i,y)$

Output: Final grasp pose $\mathbf{x}_0$
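Algorithm 1 can be sketched in Python as follows. This is an illustrative sketch rather than the released implementation: `f_theta` is any callable with the signature `(x, t, y)`, the embedding `y` is assumed to be precomputed, and the example noise scheduler uses an assumed constant schedule.

```python
import math
import random


def multistep_sample(f_theta, y, time_points, alpha, dim, rng):
    """Multistep consistency sampling following Algorithm 1.

    time_points : increasing list [t_1 = eps, ..., t_P = T]
    alpha       : noise scheduler, alpha(t) = exp(rho_t), with values in (0, 1]
    """
    T = time_points[-1]
    x_T = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    x0 = f_theta(x_T, T, y)
    # 1-indexed i = P-1 down to 2 maps to 0-indexed P-2 down to 1.
    for idx in range(len(time_points) - 2, 0, -1):
        t_i = time_points[idx]
        a = alpha(t_i)
        # Re-noise the current estimate to time t_i, then denoise again.
        x_t = [math.sqrt(a) * v + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0)
               for v in x0]
        x0 = f_theta(x_t, t_i, y)
    return x0


def toy_model(x, t, y):
    """Hypothetical stand-in for a trained model; always returns a fixed pose."""
    return [0.1, 0.2]


pose = multistep_sample(toy_model, None, [1.0, 500.0, 1000.0],
                        lambda t: math.exp(-0.005 * t), 2, random.Random(0))
print(pose)
```

With $P$ time points the model is evaluated $P - 1$ times, which is where the latency saving over a step-by-step discrete sampler comes from.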

Training and Inference

During training, we freeze the text and image encoders, then train the ALBEF fusion, the score network, and the consistency model end-to-end. The score network and the conditional consistency model share the same architecture. We trained both models simultaneously for 1000 epochs with a batch size of 8 using the Adam optimizer. Training takes approximately three days on an NVIDIA A100 GPU. Regarding the parameters of the conditional consistency model, we empirically set $T = 1000$, $\epsilon = 1$, and $N = 2000$. After training the score network and the conditional consistency model $\mathbf{f}_\theta(\mathbf{x}_t,t,y)$, we can sample the grasp pose given the input image and language instruction prompt in a few denoising steps using Algorithm 1.

Figure 2: Robot hands with different utilities.

Next

In the next post, we will evaluate the effectiveness of our proposal.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 4)

We evaluate our proposed dataset, Guide3D, through a structured experimental analysis, as follows: i) initially, we assess the dataset’s validity, focusing on reprojection errors and their distribution across the dataset to understand its accuracy; ii) we then explore the applicability of Guide3D in a 3D reconstruction task; and iii) finally, we benchmark several segmentation algorithms to assess their performance on Guide3D, providing insights into the dataset’s utility.

Dataset Validation

Our analysis revealed a non-uniform distribution of reprojection errors across the dataset, with the highest variability and errors concentrated at the proximal end of the guidewire reconstructions. Figure 1 shows the reprojection error patterns for both Camera A and Camera B. For Camera A, mean errors increase from approximately 6 px to a peak of 20 px, with standard deviations rising from 5 px to 11 px, indicating growing inaccuracies and variability over time. Significant fluctuations around indices 25 to 27 highlight periods of particularly high error. For Camera B, mean errors exhibit an initial peak of 9 px at index 1, followed by fluctuations that decrease towards the end. The standard deviations for Camera B start high at 11 px and decrease over time, reflecting a pattern of high initial variability that stabilizes later. These patterns are consistent with the inherent flexibility of the guidewire, which can form complex shapes such as loops.
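For reference, a reprojection error of the kind reported above can be computed by projecting a reconstructed 3D point into a camera and measuring the pixel distance to the observed 2D point. The sketch below assumes a pinhole model with a 3x4 projection matrix; the matrix values and points are made up for illustration and are not the calibration of Camera A or Camera B.

```python
import math


def project(P, point3d):
    """Project a 3D point to pixel coordinates with a 3x4 projection matrix P."""
    homog = list(point3d) + [1.0]
    u, v, w = (sum(p * x for p, x in zip(row, homog)) for row in P)
    return (u / w, v / w)


def reprojection_error(P, point3d, observed_px):
    """Euclidean distance in pixels between the projection and the observation."""
    u, v = project(P, point3d)
    return math.hypot(u - observed_px[0], v - observed_px[1])


# Hypothetical camera: focal length 1000 px, principal point (500, 500).
P = [[1000.0, 0.0, 500.0, 0.0],
     [0.0, 1000.0, 500.0, 0.0],
     [0.0, 0.0, 1.0, 0.0]]
print(reprojection_error(P, (0.1, -0.2, 2.0), (553.0, 404.0)))
```

Computing this error per sampled point along the guidewire, and then taking the mean and standard deviation per index, gives curves of the kind summarized for the two cameras.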

Figure 1. Guidewire Reconstruction Error Analysis: (Left) Illustrates the distribution of reprojection errors, noting higher variability and peak errors in the mid-sections and reduced errors at the extremities. (Right) Presents the results of reconstruction validation.

Furthermore, we conducted a validation procedure using CathSim, incorporating the aortic arch model described in the next subsection and a guidewire of similar diameter and properties. For sampling, we employed the soft actor-critic (SAC) algorithm with segmented guidewires and kinematic data, producing realistic validation samples. Evaluation metrics included maximum Euclidean distance (MaxED) at 2.880 ± 0.640 mm, mean error in tip tracking (METE) at 1.527 ± 0.877 mm, and mean error related to the robot’s shape (MERS) at 0.001 ± 0.000. These results demonstrate the method’s precision.

Guidewire Prediction Results

We now demonstrate the capability of the introduced network and highlight the importance of the proposed dataset. We examine the network prediction in the following manner: 1) we first conduct an analysis between the predicted and reconstructed curve by employing piecewise metrics, and 2) we showcase the reprojection error.

Shape Prediction Errors: Table 1 presents the comparison of different metrics for shape prediction accuracy. We quantify the shape differences using the following metrics: 1) Maximum Euclidean Distance (MaxED), 2) Mean Error in Tip Tracking (METE), and 3) Mean Error in Robot Shape (MERS). For all the metrics, the shape of the guidewire, represented as a 3D curve C(u)\mathbf{C}(u), is sampled at equidistant Δu\Delta u intervals along the arclength parameter uu. Therefore, the metrics represent the pointwise discrepancies between the two shapes along the curve’s arclength.
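A minimal sketch of how such pointwise shape errors can be computed for one predicted curve is given below. The exact definitions are assumptions made for illustration: the maximum pointwise distance for MaxED, the distance at the last sample (taken here as the tip) for the tip error, and the mean pointwise distance for the shape error. Both curves are assumed to be sampled at the same arclength parameters.

```python
import math


def pointwise_distances(pred, gt):
    """Euclidean distances between corresponding samples of two 3D curves."""
    if len(pred) != len(gt):
        raise ValueError("curves must be sampled at the same parameters")
    return [math.dist(p, g) for p, g in zip(pred, gt)]


def shape_errors(pred, gt):
    """Per-curve errors; averaging over a test set gives dataset-level metrics."""
    d = pointwise_distances(pred, gt)
    return {
        "max_ed": max(d),                    # largest pointwise distance
        "tip_error": d[-1],                  # assumes the last sample is the tip
        "mean_shape_error": sum(d) / len(d)  # mean pointwise distance
    }


pred = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
gt = [(0.0, 0.0, 0.0), (1.0, 3.0, 0.0), (2.0, 0.0, 4.0)]
print(shape_errors(pred, gt))
```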

The results indicate that the spherical representation consistently outperforms the Cartesian representation across all metrics. Specifically, the Maximum Euclidean Distance (MaxED) shows a lower error in the spherical representation (6.88 ± 5.23 mm) compared to the Cartesian representation (10.00 ± 4.64 mm). Similarly, the Mean Error in Tip Tracking (METE) is significantly lower in the spherical representation (3.28 ± 2.59 mm) than in the Cartesian representation (6.93 ± 3.94 mm). For the Mean Error in Robot Shape (MERS), the spherical representation also demonstrates a reduced error (4.54 ± 3.67 mm) compared to the Cartesian representation (5.33 ± 2.73 mm). Lastly, the Fréchet distance shows a smaller error for the spherical representation (6.70 ± 5.16 mm) compared to the Cartesian representation (8.95 ± 4.37 mm). These results highlight the advantage of using the spherical representation for more accurate shape prediction.

Table 1 Shape Comparison (mm).

Shape Comparison Visualization: Figure 2a showcases two 3D plots from different angles, comparing the ground truth guidewire shape to the predicted shape by the network. The network demonstrates its capability to accurately predict the guidewire shape, even in the presence of a loop and self-obstruction in the image. The predicted shape aligns closely with the actual configuration of the guidewire. Notably, the proximal end manifests a more substantial error relative to the nominal error seen at the distal end. Discrepancies from the authentic guidewire shape span from a mere 2 mm at the distal end to a noticeable 5 mm at the proximal end. Impressively, the network evidences its capability to accurately predict the guidewire’s shape using only consecutive singular plane images. Subsequently, the 3D points are reprojected onto the original images, as illustrated in Figure 2b.

Figure 2. The figure illustrates the reconstruction similarity of the guidewire when reprojected onto the images. It demonstrates the network’s capability to accurately predict the guidewire shape, even in the presence of noticeable angles, highlighting the robustness of the prediction model.

Segmentation Results

We demonstrate Guide3D’s potential to advance guidewire segmentation research by evaluating the performance of three state-of-the-art network architectures: UNet (learning rate: $1 \times 10^{-5}$, 135 epochs), TransUnet (integrating ResNet50 and Vision Transformer (ViT-B-16), learning rate: 0.01, 199 epochs), and SwinUnet (Swin Transformer architecture, learning rate: 0.01, 299 epochs). Performance metrics included the Dice coefficient (DiceM), mean Intersection over Union (mIoU), and Jaccard index, detailed in Table 2. The results indicate that UNet achieved a DiceM of 92.25, mIoU of 36.60, and Jaccard index of 86.57. TransUnet performed best, with a DiceM of 95.06, mIoU of 41.20, and Jaccard index of 91.10. SwinUnet recorded a DiceM of 93.73, mIoU of 38.58, and Jaccard index of 88.55. These findings benchmark the dataset’s performance and suggest potential for future enhancements. Despite these promising results, the presence of loops and occlusions within the guidewire indicates that polyline prediction could significantly improve task utility.
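As a reference for the overlap metrics, the Dice coefficient and Jaccard index of a binary mask can be computed as in the sketch below. It operates on flat sequences of 0/1 labels for a single mask; how the reported numbers are averaged over images or classes is not specified here, so this is only an illustration of the per-mask definitions.

```python
def dice_and_jaccard(pred, target):
    """Dice coefficient and Jaccard index for two flat binary masks."""
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    pred_sum = sum(1 for p in pred if p)
    target_sum = sum(1 for t in target if t)
    union = pred_sum + target_sum - inter
    # Two empty masks are treated as a perfect match.
    dice = 2.0 * inter / (pred_sum + target_sum) if pred_sum + target_sum else 1.0
    jaccard = inter / union if union else 1.0
    return dice, jaccard


dice, jaccard = dice_and_jaccard([1, 1, 0, 0], [1, 0, 1, 0])
print(dice, jaccard)
```

For a single mask the two are linked by $J = D / (2 - D)$, so they always rank predictions in the same order.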

Table 2 Segmentation Results.

Discussion and Conclusion

This paper introduces a new dataset, Guide3D, for segmentation and 3D reconstruction of flexible, curved endovascular tools. Extensive experiments demonstrate the dataset’s value; yet several limitations must be acknowledged. Firstly, our dataset lacks real clinical human data due to the complexity and regulatory challenges of acquiring such data. Our standardized platform, however, aims to enable further research, providing a stepping stone towards clinical practice.

Additionally, the dataset primarily focuses on synthetic and experimental scenarios, which may not fully capture the variability and unpredictability of real-world clinical environments. While this controlled setting aids initial algorithm development and benchmarking, further validation with clinical data is necessary to ensure the robustness and generalizability of the proposed methods.

Moreover, the guidewire’s flexibility and the presence of loops and occlusions present significant challenges for segmentation and reconstruction tasks. Our dataset includes these complexities to push the boundaries of current methodologies, but future work should explore more advanced techniques.

Our dataset accommodates both video and image-based approaches, providing a versatile resource to facilitate the translation of these technologies into clinical settings. Our objective is to bridge the disparity between research developments and clinical application by establishing a standardized framework for evaluating the efficacy of various methodologies. Our code and dataset will be made publicly available.

Lightweight Language-driven Grasp Detection using Conditional Consistency Model (Part 1)

Language-driven grasp detection is a fundamental yet challenging task in robotics with various industrial applications. This work presents a new approach for language-driven grasp detection that leverages lightweight diffusion models to achieve fast inference time. By integrating diffusion processes with grasping prompts in natural language, our method can effectively encode visual and textual information, enabling more accurate and versatile grasp positioning that aligns well with the text query. To overcome the long inference time problem in diffusion models, we leverage the image and text features as the condition in the consistency model to reduce the number of denoising timesteps during inference. The intensive experimental results show that our method outperforms other recent grasp detection methods and lightweight diffusion models by a clear margin. We further validate our method in real-world robotic experiments to demonstrate its fast inference time capability.


1. Introduction

Grasping is one of the fundamental tasks in robotics, enabling robots to interact with the physical world through a broad spectrum of applications, from industrial automation and human-robot interaction to service robotics. Recent advancements in machine vision have significantly improved the capabilities of grasp detection for the robot. Prior research has demonstrated encouraging grasp detection results in both 2D images and 3D point clouds. However, most existing works define grasp detection as a region localization problem while ignoring the use of natural language to localize possible grasps on the object based on linguistic input.

Figure 1: Virtual demonstration of grasping a commanded object.

With the recent advances in Large Language Models (LLMs), integrating language into robotic systems has become more popular. Pretrained models such as ChatGPT and CLIP have revolutionized various applications, and their adaptability to the robotic domain has shown encouraging results. Although several works address language-driven robotic manipulation, most focus on understanding high-level actions and overlook the fundamental grasping task. In this paper, we tackle the language-driven grasp detection task, which allows the robot to grasp specific objects based on a language command. With language-driven grasping ability, robots can interact more effectively with the surrounding environment and with humans.

Language-driven grasping offers several advantages compared to the traditional grasp detection task without text. Firstly, we communicate with robots by providing language prompts that direct them to execute precise tasks; therefore, the incorporation of natural language instructions augments robotic systems with the ability to respond to dynamic, real-time tasks interactively. Secondly, using natural language addresses the challenge of ambiguity in identifying target objects within cluttered environments or distinguishing among objects with similar shapes. Lastly, linguistic guidance enriches robotic systems with semantic information, enhancing their learning capabilities without necessitating expert demonstrations or specific engineering.

Several works on grasp detection have recently utilized diffusion models as the essential technique and shown encouraging results. This is motivated by the proven efficacy of diffusion models in conditional generation tasks such as image synthesis, image segmentation, and visual grounding. The effectiveness of diffusion models comes from their iterative approach, which gradually refines data from an initial state of pure noise toward a meaningful output. Nonetheless, applying diffusion models to language-driven tasks in robotics faces a key challenge: their inference time is usually not fast enough for real-time robotic applications. Consequently, recent studies have introduced techniques to tackle the inference speed problem of diffusion models, using approaches such as rapid sampling, knowledge distillation, or model optimization. However, these models still cannot perform fast sampling with language conditions during inference to meet the real-time requirement of robotic grasping.

In this paper, we propose a new lightweight diffusion model to tackle the inference speed problem in using diffusion models for the language-driven grasp detection task. To this end, we exploit the capabilities of flow-based generative models to improve the precision of robots in identifying grasp poses from textual inputs. In particular, we develop a conditional consistency model that achieves the fast inference speed needed for real-time robotic applications. We verify our proposed method on a recent large-scale language-driven grasping dataset and achieve superior accuracy and inference speed compared with recent approaches. Furthermore, our method enables zero-shot learning and generalizes to real-world robotic grasping applications.

Our contributions are summarized as follows:

  • We present Lightweight Language-driven Grasp Detection (LLGD), a fast diffusion model for language-driven grasp detection.
  • We conduct intensive analysis to validate our method and demonstrate that it outperforms other approaches in terms of both accuracy and execution speed.

2. Related Works

Grasp Detection. Grasp detection has been a central topic in robotics, aiming to equip robots with the ability to identify and execute object grasping in complex environments. Several works have set the foundation for robot grasping by using convolutional neural networks (CNNs). Most previous grasp detection methods are often limited to simple tasks with a fixed number of classes and rely solely on raw image data. Several works have extended the problem by using RGB-D images or 3D point clouds to output the results in 3D space. However, they still have not focused on integrating language as the input instruction in the grasp detection problem.

Language-driven Grasping. Language-driven grasp detection introduces the use of natural language to inform grasp detection tasks. The standard approach is to divide the task into two stages: the first identifies the target object, and the second generates grasp poses based on the established visual-text correlations. Foundation models such as GroundDINO and CLIP have emerged, enabling zero-shot detection and segmentation. These models allow the target object to be localized without training. However, due to their large size, they result in longer inference times. Accessing commercial foundation models is also not always possible, especially since LLMs often require using APIs, which come at a high cost.

Lightweight Diffusion Model. Lightweight diffusion models that maintain performance while reducing computational overhead have become crucial in machine learning. Researchers have utilized knowledge distillation for low-resolution features to reduce the number of parameters in U-Net. Recently, consistency models have surfaced as a robust class of generative models capable of producing high-quality images within a single step or a limited number of steps. Although they have significant applications in generative tasks, these models are primarily unconditional. Robotic applications, on the other hand, remain discriminative, making unconditional diffusion models not entirely suitable. In this study, we address this issue by building a lightweight diffusion model with language conditions. We aim to extend the consistency model so that it keeps its fast inference time while adding the language conditions that make it more suitable for the language-driven grasping task.

Figure 2: GraspNet dataset, a widely used dataset for grasp detection.

Next

In the next post, we will introduce our proposal Lightweight Language-driven Grasp Detection using Conditional Consistency Model.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 3)

Utilizing the Guide3D dataset, we build a benchmark for the shape prediction task, a critical component in endovascular intervention. Accurate shape prediction of the guidewire is essential for successful navigation and intervention. Here, we introduce a novel shape prediction network designed to predict the guidewire shape from a sequence of monoplanar images. This approach leverages deep learning to learn spatio-temporal correlations from a static camera observing a dynamic scene. Unlike conventional reconstruction methods that require biplanar images, our network uses a sequence of images to extract temporal information, allowing it to map a single image $\mathbf{I}_A$ to the 3D guidewire curve $\mathbf{C}(\mathbf{u})$. By adopting this deep learning approach, we aim to simplify the shape prediction process while maintaining high accuracy. This method has the potential to enhance endovascular navigation by providing real-time, accurate predictions of the guidewire shape, ultimately improving procedural outcomes and reducing reliance on specialized equipment.

Network Key Components: The figure illustrates the essential components of the proposed model. a) Spherical coordinates $(r, \theta, \phi)$ are used for predicting the guidewire shape. b) The model predicts the 3D shape of a guidewire from image sequences $\mathbf{I}_t$. A Vision Transformer (ViT) extracts spatial features $\mathbf{z}_t$, which a Gated Recurrent Unit (GRU) processes to capture temporal dependencies, producing hidden states $\mathbf{h}_t$. The final hidden state drives three prediction heads: the Tip Prediction Head for the 3D tip position $\mathbf{p} \in \mathbb{R}^3$, the Spherical Offset Prediction Head for coordinate offsets $(\Delta \phi, \Delta \theta)$, and the Stop Prediction Head for terminal point probability $\mathbf{S}$.

Spherical Coordinates Representation

Predicting 3D points directly can be challenging due to the high degree of freedom. To mitigate this, we use spherical coordinates, which offer significant advantages over Cartesian coordinates for guidewire shape prediction. Spherical coordinates, as represented in Fig. 1a, are defined by the radius $r$, polar angle $\theta$, and azimuthal angle $\phi$. They provide a more natural representation for the position and orientation of points along the guidewire, which is typically elongated and curved.

Mathematically, a point in spherical coordinates $(r, \theta, \phi)$ can be converted to Cartesian coordinates $(x, y, z)$ using the transformations:

$$x = r \sin \theta \cos \phi, \quad y = r \sin \theta \sin \phi, \quad z = r \cos \theta.$$

This conversion simplifies the modeling of angular displacements and rotations, as spherical coordinates directly encode directional information.
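The conversion above can be written in a few lines of NumPy. This is a minimal sketch; the function names are illustrative, and the inverse mapping is included only as a sanity check (it assumes a non-zero radius).

```python
import numpy as np

def spherical_to_cartesian(r, theta, phi):
    """Map radius r, polar angle theta, azimuth phi to (x, y, z)."""
    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)
    return np.stack([x, y, z], axis=-1)

def cartesian_to_spherical(xyz):
    """Inverse mapping; assumes every point has non-zero radius."""
    xyz = np.asarray(xyz, dtype=float)
    r = np.linalg.norm(xyz, axis=-1)
    theta = np.arccos(np.clip(xyz[..., 2] / r, -1.0, 1.0))
    phi = np.arctan2(xyz[..., 1], xyz[..., 0])
    return r, theta, phi

# A point on the x-axis at distance 2: theta = pi/2, phi = 0.
point = spherical_to_cartesian(2.0, np.pi / 2, 0.0)
r_back, theta_back, phi_back = cartesian_to_spherical(point)
print(point, r_back, theta_back, phi_back)
```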

Predicting angular displacements $(\Delta \theta, \Delta \phi)$ relative to a known radius $r$ aligns with the physical constraints of the guidewire, facilitating more accurate and interpretable shape predictions. By predicting an initial point (tip position) and representing subsequent points as offsets in $\Delta \phi$ and $\Delta \theta$ while keeping $r$ fixed, this method simplifies shape comparison and reduces the parameter space. This approach enhances the model’s ability to capture the guidewire’s spatial configuration and improves overall prediction performance.
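One way to read this representation is sketched below: starting from the tip, each pair of angular offsets updates a running direction, and the next point lies a fixed distance $r$ along that direction. The accumulation rule, the starting angles, and the function name are assumptions made for illustration; the post does not spell out how the offsets are chained.

```python
import numpy as np

def points_from_offsets(tip, d_theta, d_phi, r, theta0=0.0, phi0=0.0):
    """Rebuild a 3D polyline from a tip point and per-step angular offsets.

    Assumption: offsets are accumulated into running angles, and each step
    advances by the fixed radius r along the current direction.
    """
    theta = theta0 + np.cumsum(d_theta)
    phi = phi0 + np.cumsum(d_phi)
    steps = r * np.stack(
        [np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)],
        axis=-1,
    )
    tip = np.asarray(tip, dtype=float)
    return np.vstack([tip, tip + np.cumsum(steps, axis=0)])

# Three steps that all point along +x (theta = pi/2, phi = 0).
curve = points_from_offsets(
    tip=[0.0, 0.0, 0.0],
    d_theta=np.array([np.pi / 2, 0.0, 0.0]),
    d_phi=np.zeros(3),
    r=1.0,
)
print(curve)
```

Because the step length is fixed, only two numbers per point need to be predicted instead of three, which is the reduction in parameter space mentioned above.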

Network Architecture

The proposed model (shown in Fig. 1b) addresses the problem of predicting the 3D shape of a guidewire from a sequence of images. Each image sequence captures the guidewire at different time steps $\mathbf{I}_{A,t}$, and the goal is to infer the continuous 3D shape $\mathbf{C}_t(\mathbf{u}_t)$. This many-to-one prediction task is akin to generating a variable-length sequence from variable-length input sequences, a technique commonly utilized in fields such as machine translation and video analysis.

To achieve this, the input pipeline consists of a sequence of images depicting the guidewire. A Vision Transformer (ViT), pre-trained on ImageNet, is employed to extract high-dimensional spatial feature representations from these images. The ViT generates feature maps $\mathbf{z}_t \in \mathbb{R}$. These feature maps are then fed into a Gated Recurrent Unit (GRU) to capture the temporal dependencies across the image sequence. The GRU processes the feature maps $\mathbf{z}_t$ from consecutive time steps, producing a sequence of hidden states $\mathbf{h}_t$. Formally, the GRU operation at time step $t$ is defined as:

$$\mathbf{h}_t = \text{GRU}(\mathbf{z}_t, \mathbf{h}_{t-1}).$$

The final hidden state $\mathbf{h}_t$ from the GRU is used by three distinct prediction heads, each tailored to a specific aspect of the guidewire shape prediction. The Tip Prediction Head predicts the 3D coordinates of the guidewire’s tip through a fully connected layer that maps the hidden state $\mathbf{h}_t$ to a Cartesian anchoring point $\mathbf{p} \in \mathbb{R}^3$. The Spherical Offset Prediction Head predicts the spherical coordinate offsets $(\Delta \phi, \Delta \theta)$ for points along the guidewire with a fixed radius $r$. The Stop Prediction Head outputs the probability distribution indicating the terminal point of the guidewire, using a softmax layer to produce a probability tensor $\mathbf{S}$, where each element $\mathbf{S}_j$ is the probability of the $j$-th point being the terminal point.
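The data flow described above (per-frame features, a GRU recurrence, then three heads on the final hidden state) can be sketched in plain NumPy. Everything concrete here is a stand-in chosen for illustration: the feature and hidden sizes, the number of points, the random weights, and the random vectors used in place of ViT features. Biases are omitted, and nothing is trained.

```python
import numpy as np

rng = np.random.default_rng(0)
FEAT_DIM, HID_DIM, NUM_POINTS = 16, 32, 8  # hypothetical sizes, not from the post

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def init(rows, cols):
    return rng.normal(0.0, 0.1, size=(rows, cols))

# GRU parameters (biases omitted for brevity).
W_u, U_u = init(HID_DIM, FEAT_DIM), init(HID_DIM, HID_DIM)  # update gate
W_r, U_r = init(HID_DIM, FEAT_DIM), init(HID_DIM, HID_DIM)  # reset gate
W_c, U_c = init(HID_DIM, FEAT_DIM), init(HID_DIM, HID_DIM)  # candidate state

def gru_step(z_t, h_prev):
    """One recurrence step: h_t = GRU(z_t, h_{t-1})."""
    u = sigmoid(W_u @ z_t + U_u @ h_prev)
    r = sigmoid(W_r @ z_t + U_r @ h_prev)
    c = np.tanh(W_c @ z_t + U_c @ (r * h_prev))
    return (1.0 - u) * h_prev + u * c

# One linear layer per prediction head.
W_tip = init(3, HID_DIM)
W_off = init(2 * NUM_POINTS, HID_DIM)
W_stop = init(NUM_POINTS, HID_DIM)

def predict(features):
    h = np.zeros(HID_DIM)
    for z_t in features:          # iterate over time steps
        h = gru_step(z_t, h)
    tip = W_tip @ h                                 # 3D tip position
    offsets = (W_off @ h).reshape(NUM_POINTS, 2)    # (delta_phi, delta_theta) per point
    stop = softmax(W_stop @ h)                      # terminal-point probabilities
    return tip, offsets, stop

features = rng.normal(size=(5, FEAT_DIM))  # stand-in for ViT features of 5 frames
tip, offsets, stop = predict(features)
print(tip.shape, offsets.shape, stop.sum())
```

Reading the curve length from the stop head (for example, the index with the highest probability) is what lets a fixed-size output describe guidewires of different lengths.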

Loss Function

The custom loss function for training the model combines multiple components to handle the point-wise tip error, variable guidewire length (stop criteria), and tip position predictions. The overall loss function $\mathcal{L}_{\text{total}}$ is defined as:

$$\mathcal{L}_{\text{total}} = \frac{1}{N} \sum_{i=1}^N \bigg( \lambda_{\text{tip}} \left\| \hat{\mathbf{p}}_i - \mathbf{p}_i \right\|^2 + \lambda_{\text{offset}} \big( (\hat{\boldsymbol{\phi}}_i - \boldsymbol{\phi}_i)^2 + (\hat{\boldsymbol{\theta}}_i - \boldsymbol{\theta}_i)^2 \big) + \lambda_{\text{stop}} \big( -\mathbf{s}_i \log (\hat{\mathbf{s}}_i) - (1 - \mathbf{s}_i) \log (1 - \hat{\mathbf{s}}_i) \big) \bigg)$$

where $N$ is the number of samples, and $\lambda_{\text{tip}}$, $\lambda_{\text{offset}}$, and $\lambda_{\text{stop}}$ are weights that balance the contributions of each loss component. The tip prediction loss ($\mathcal{L}_{\text{tip}}$) uses mean squared error (MSE) to ensure accurate 3D tip coordinates. The spherical offset loss ($\mathcal{L}_{\text{offset}}$) also uses MSE to align predicted and ground truth angular offsets, capturing the guidewire’s shape. The stop prediction loss ($\mathcal{L}_{\text{stop}}$) employs binary cross-entropy (BCE) to accurately predict the guidewire’s endpoint.
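A small NumPy sketch of the combined loss follows. It assumes each sample carries one tip vector and arrays of per-point angles and stop labels, and it averages the offset and stop terms over the points of a sample; that reduction, the default weights of 1.0, and the function name are assumptions rather than details given in the post.

```python
import numpy as np

def guidewire_loss(p_hat, p, phi_hat, phi, theta_hat, theta, s_hat, s,
                   lam_tip=1.0, lam_offset=1.0, lam_stop=1.0, eps=1e-7):
    """Weighted sum of tip, offset, and stop terms, averaged over N samples.

    Shapes: p, p_hat are (N, 3); the remaining arrays are (N, M) for M points.
    """
    tip = np.sum((p_hat - p) ** 2, axis=-1)                     # squared tip error
    offset = np.mean((phi_hat - phi) ** 2 + (theta_hat - theta) ** 2, axis=-1)
    s_hat = np.clip(s_hat, eps, 1.0 - eps)                      # avoid log(0)
    stop = np.mean(-s * np.log(s_hat) - (1.0 - s) * np.log(1.0 - s_hat), axis=-1)
    return float(np.mean(lam_tip * tip + lam_offset * offset + lam_stop * stop))

# Toy batch (hypothetical data): N = 2 samples, M = 3 points each.
p = np.zeros((2, 3))
p_hat = p + np.array([1.0, 0.0, 0.0])       # tip off by 1 unit along x
angles = np.zeros((2, 3))
s = np.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
loss = guidewire_loss(p_hat, p, angles, angles, angles, angles, s, s)
print(loss)
```

With perfect offsets and stop predictions, the value is dominated by the tip term, which makes it easy to see how the three weights trade the components off against each other.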

Training Details

The model was trained end-to-end using the loss function defined above. The NAdam optimizer was used with an initial learning rate of $1 \times 10^{-4}$. Additionally, a learning rate scheduler was employed to adjust the learning rate dynamically based on the validation loss. Specifically, the ReduceLROnPlateau scheduler was configured to reduce the learning rate by a factor of 0.1 if the validation loss did not improve for 10 epochs. The model was trained for 400 epochs, with early stopping based on the validation loss to further prevent overfitting.
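The plateau rule described above can be illustrated with a small standalone class. This is a simplified mimic written for this example, not the PyTorch `ReduceLROnPlateau` implementation; in particular, the exact epoch at which the reduction triggers is an assumption of this sketch, and the validation losses are made up.

```python
class PlateauScheduler:
    """Multiply the learning rate by `factor` after `patience` epochs
    without improvement in the validation loss (simplified sketch)."""

    def __init__(self, lr=1e-4, factor=0.1, patience=10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

# Hypothetical validation losses: two improving epochs, then a 10-epoch plateau.
sched = PlateauScheduler()
losses = [1.0, 0.9] + [0.95] * 10
history = [sched.step(v) for v in losses]
print(history[-2], history[-1])
```

Early stopping can reuse the same bookkeeping: instead of scaling the learning rate when the counter reaches its limit, the training loop simply exits.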

Next

In the next part, we will validate the effectiveness of the guidewire shape prediction dataset and methodology.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 2)

We propose the Guide3D Dataset, a comprehensive resource specifically designed to advance 3D reconstruction and segmentation in endovascular navigation. This dataset addresses key limitations in the field, such as the scarcity of high-quality, publicly accessible datasets, by providing a diverse collection of real and synthetic imaging data. Guide3D includes detailed annotations for guidewire and catheter segmentation, alongside multi-view fluoroscopic data that supports accurate 3D modeling. By offering a standardized platform for algorithm development and evaluation, Guide3D aims to bridge the gap between research and clinical practice, facilitating improvements in precision, visualization, and tool tracking during endovascular procedures. Through this dataset, we seek to accelerate innovation in medical imaging, contributing to safer and more effective interventions.

Data Collection Setup

X-ray System. Our setup employed a bi-planar X-ray system equipped with 60 kW Epsilon X-ray generators and 16-inch image intensifier tubes by Thales, featuring dual focal spot Varian X-ray tubes for high-definition imaging. The system included Ralco automatic collimators for precise alignment and exposure, with calibration achieved through the use of acrylic mirrors and geometric alignment grids.

Anatomical Models. We utilized a half-body vascular phantom model from Elastrat Sarl Ltd., Switzerland, enclosed in a transparent box and integrated into a closed water circuit to simulate blood flow. Made from soft silicone and equipped with compact continuous flow pumps, it replicates human blood flow dynamics. The design is based on detailed postmortem vascular casts, ensuring anatomical accuracy reflective of human vasculature, facilitating realistic vascular simulations.

Figure 1. Dataset Overview: Guide3D contains 8,746 manually annotated frames from two views for 3D reconstruction (left), from which the reconstruction is derived (right).

Surgical Tools. To enhance our dataset, we navigated complex vascular structures using two types of guidewires commonly used in real-world endovascular surgery. The first, the Radifocus™ Guide Wire M Stiff Type (Terumo Ltd.), is made from nitinol with a polyurethane-tungsten coating. It measures 0.89 mm in diameter and 260 cm in length, with a 3 cm angled tip, designed for seeking, dissecting, and crossing lesions. The second, the Nitrex Guidewire (Nitrex Metal Inc.), also made of nitinol, features a gold-tungsten straight tip for enhanced radiopacity in fluoroscopic visualization. It has a diameter of 0.89 mm and a length of 400 cm, with a 15 cm tip, and is generally used for accessing or maintaining position during catheter exchanges. Both guidewires were selected to reflect real-world usage and to diversify the data in our dataset.

Figure 2. Materials: a) Overall setup & endovascular phantom, b) Radifocus (angled) guidewire, and c) Nitrex (straight) guidewire.

Data Acquisition, Labeling, and Statistics

Using the materials described in the Data Collection Setup section, we compiled a dataset of 8,746 high-resolution samples (1,024 × 1,024 pixels). This dataset includes 4,373 paired instances, both with and without a simulated blood flow medium. Specifically, it consists of 6,136 samples from the Radifocus guidewire and 2,610 from the Nitrex guidewire, providing a solid foundation for automated guidewire tracking in bi-planar scanner images. Manual annotation was carried out using the Computer Vision Annotation Tool (CVAT), where polylines were created to accurately track the dynamic path of the guidewires. The polyline representation was chosen because the guidewire's structure often results in overlapping sections, making a segmentation mask unsuitable. In contrast, a polyline effectively captures the looping nature of the guidewire, offering greater accuracy.

As shown in Table 1, the dataset includes 3,664 instances of angled guidewires with fluid and 484 without, while straight guidewires are represented by 2,472 instances with fluid and 2,126 without. This distribution reflects a variety of procedural contexts. All 8,746 images in the dataset are accompanied by manual segmentation ground truth, facilitating the development of algorithms that require segmentation maps as reference data.

Table 1. Dataset Composition Overview.

Calibration

We extract the camera parameters using a traditional undistortion and calibration method. Undistortion is first achieved with a local weighted mean (LWM) algorithm, using a perforated steel sheet with a hexagonal pattern as a framing reference, and applying a blob detection algorithm to precisely identify distortion points. This approach establishes correspondences between distorted and undistorted positions, allowing for accurate distortion correction.

Following this, a semi-automatic calibration step is performed for marker identification, and the random sample consensus (RANSAC) method is used to ensure robustness in computing the projection matrix and deriving the intrinsic and extrinsic camera parameters. The calibration process is further refined through direct linear transformation (DLT) and non-linear optimization, utilizing multiple poses of the calibration object to optimize the overall camera setup. Figure 3 illustrates the calibration process.

Figure 3. Fluoroscopic Calibration: a) Undistortion grid application, and b) Point identification on calibration frame.
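To make the DLT step more concrete, the sketch below estimates a 3×4 projection matrix from known 3D marker positions and their 2D image locations. It covers only the linear DLT solve; the undistortion, RANSAC, and non-linear refinement stages are left out, and the camera intrinsics and marker positions are synthetic values invented for the example.

```python
import numpy as np

def project(P, X):
    """Project (n, 3) world points with a 3x4 matrix P to (n, 2) pixels."""
    Xh = np.hstack([X, np.ones((len(X), 1))])
    xh = Xh @ P.T
    return xh[:, :2] / xh[:, 2:3]

def estimate_projection_dlt(X, x):
    """Linear DLT: solve A p = 0 for the 12 entries of P (up to scale).

    Needs at least 6 non-coplanar correspondences.
    """
    rows = []
    zero = np.zeros(4)
    for Xw, (u, v) in zip(X, x):
        Xh = np.append(Xw, 1.0)
        rows.append(np.concatenate([Xh, zero, -u * Xh]))
        rows.append(np.concatenate([zero, Xh, -v * Xh]))
    _, _, Vt = np.linalg.svd(np.asarray(rows))
    return Vt[-1].reshape(3, 4)  # right singular vector of the smallest singular value

# Synthetic camera and markers (hypothetical values).
rng = np.random.default_rng(0)
K = np.array([[1000.0, 0.0, 512.0], [0.0, 1000.0, 512.0], [0.0, 0.0, 1.0]])
P_true = K @ np.hstack([np.eye(3), [[0.0], [0.0], [5.0]]])
X = rng.uniform(-1.0, 1.0, size=(12, 3))
x = project(P_true, X)

P_est = estimate_projection_dlt(X, x)
reproj_error = np.abs(project(P_est, X) - x).max()
print(reproj_error)
```

In practice the detected marker positions are noisy, which is why a robust estimator and a non-linear refinement are layered on top of a linear solve of this kind.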

Guidewire Reconstruction

Given polyline representations of a curve in both planes, the reconstruction process begins by parameterizing these curves using B-Spline interpolation. Each curve is expressed as a function of the cumulative distance along its path. Let $\mathbf{C}_A(\mathbf{u}_A)$ and $\mathbf{C}_B(\mathbf{u}_B)$ represent the parameterized B-Spline curves in their respective planes, where $\mathbf{u}_A$ and $\mathbf{u}_B$ are the normalized arc-length parameters. The corresponding $\mathbf{u}_B$ for a given $\mathbf{u}_A$ is found using epipolar geometry. Once the corresponding points $\mathbf{C}_A(\mathbf{u}_A^i)$ and $\mathbf{C}_B(\mathbf{u}_B^i)$ are identified, their 3D coordinates $\mathbf{P}^i$ are computed by triangulation, resulting in a set of 3D points $\{\mathbf{P}^i\}_{i=1}^{M}$, where $M$ is the total number of sampled points. This effectively reconstructs the original curve in 3D space.
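The triangulation step for one matched pair of points can be sketched with the standard linear (DLT) method. The post does not state which triangulation variant was used, so this is one common choice shown under that assumption; the two synthetic cameras, placed 90 degrees apart to mimic a bi-planar setup, are invented for the example.

```python
import numpy as np

def project_point(P, X):
    """Project a 3D point X with a 3x4 matrix P to pixel coordinates."""
    xh = P @ np.append(X, 1.0)
    return xh[:2] / xh[2]

def triangulate(P_A, P_B, x_A, x_B):
    """Linear triangulation of one 3D point from two views."""
    A = np.stack([
        x_A[0] * P_A[2] - P_A[0],
        x_A[1] * P_A[2] - P_A[1],
        x_B[0] * P_B[2] - P_B[0],
        x_B[1] * P_B[2] - P_B[1],
    ])
    _, _, Vt = np.linalg.svd(A)
    X = Vt[-1]
    return X[:3] / X[3]

# Two synthetic cameras (hypothetical values), rotated 90 degrees about y.
K = np.array([[1000.0, 0.0, 512.0], [0.0, 1000.0, 512.0], [0.0, 0.0, 1.0]])
t = np.array([[0.0], [0.0], [5.0]])
R_B = np.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]])
P_A = K @ np.hstack([np.eye(3), t])
P_B = K @ np.hstack([R_B, t])

X_true = np.array([0.3, -0.2, 0.4])
X_rec = triangulate(P_A, P_B, project_point(P_A, X_true), project_point(P_B, X_true))
print(X_rec)
```

Applying the same function to every matched pair along the two curves yields the set of 3D points that forms the reconstructed guidewire.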

To retrieve the fundamental matrix $\mathbf{F}$, which describes the relationship between points in Image A ($\mathbf{I}_A$) and Image B ($\mathbf{I}_B$), the condition $\mathbf{x}_B^T \mathbf{F} \mathbf{x}_A = 0$ must hold for corresponding points $\mathbf{x}_A$ in $\mathbf{I}_A$ and $\mathbf{x}_B$ in $\mathbf{I}_B$. Using the projection matrices $\mathbf{P}_A$ and $\mathbf{P}_B$ derived from the calibration process, the fundamental matrix can be calculated as follows:

$$\mathbf{F} = [\mathbf{e}_B]_\times \mathbf{P}_B \mathbf{P}_A^+$$

Here, $\mathbf{e}_B$ is the epipole in Image B, defined as $\mathbf{e}_B = \mathbf{P}_B \mathbf{C}_A$, with $\mathbf{C}_A$ being the camera center of $\mathbf{P}_A$. The skew-symmetric matrix of the epipole $\mathbf{e}_B$ is represented by:

$$[\mathbf{e}_B]_\times = \begin{bmatrix} 0 & -e_{B3} & e_{B2} \\ e_{B3} & 0 & -e_{B1} \\ -e_{B2} & e_{B1} & 0 \end{bmatrix}$$

where $\mathbf{e}_B = (e_{B1}, e_{B2}, e_{B3})^T$, and $\mathbf{P}_A^+$ is the pseudoinverse of the projection matrix $\mathbf{P}_A$. The fundamental matrix $\mathbf{F}$ encapsulates the epipolar geometry between the two views, ensuring that corresponding points $\mathbf{x}_A$ and $\mathbf{x}_B$ lie on their respective epipolar lines.
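The construction of the fundamental matrix from two projection matrices can be checked numerically. The sketch below builds the matrix from the epipole, the skew-symmetric matrix, and the pseudoinverse, then verifies that a projected 3D point lies on its epipolar line. The camera matrices and helper names are invented for the example, and the result is rescaled to unit norm because the fundamental matrix is only defined up to scale.

```python
import numpy as np

def skew(e):
    """Skew-symmetric matrix [e]_x such that skew(e) @ v == np.cross(e, v)."""
    return np.array([
        [0.0, -e[2], e[1]],
        [e[2], 0.0, -e[0]],
        [-e[1], e[0], 0.0],
    ])

def camera_center(P):
    """Homogeneous camera center: the null vector of the 3x4 matrix P."""
    _, _, Vt = np.linalg.svd(P)
    return Vt[-1]

def fundamental_from_projections(P_A, P_B):
    e_B = P_B @ camera_center(P_A)             # epipole in image B
    F = skew(e_B) @ P_B @ np.linalg.pinv(P_A)
    return F / np.linalg.norm(F)               # F is defined up to scale

def epipolar_distance(F, x_A, x_B):
    """Pixel distance from x_B to the epipolar line l_B = F x_A."""
    l_B = F @ np.append(x_A, 1.0)
    return abs(np.append(x_B, 1.0) @ l_B) / np.hypot(l_B[0], l_B[1])

def project_point(P, X):
    xh = P @ np.append(X, 1.0)
    return xh[:2] / xh[2]

# Two synthetic cameras (hypothetical values), rotated 90 degrees about y.
K = np.array([[1000.0, 0.0, 512.0], [0.0, 1000.0, 512.0], [0.0, 0.0, 1.0]])
t = np.array([[0.0], [0.0], [5.0]])
R_B = np.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]])
P_A = K @ np.hstack([np.eye(3), t])
P_B = K @ np.hstack([R_B, t])

F = fundamental_from_projections(P_A, P_B)
X = np.array([0.3, -0.2, 0.4])
dist = epipolar_distance(F, project_point(P_A, X), project_point(P_B, X))
print(dist)
```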

The matching phase begins by uniformly sampling points along the curve $\mathbf{C}_A(u_A)$ at intervals $\Delta u_A$. For each sampled point $x_A = \mathbf{C}_A(u_A)$, we project the epiline $l_B = F x_A$ into Image B. We then determine the intersection of the epiline $l_B$ with the curve $\mathbf{C}_B(u_B)$, thereby obtaining the corresponding parameter $u_B$ for each $u_A$.
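The intersection step can be illustrated on a polyline, which is swapped in here for the B-Spline curve to keep the example short. Given an epiline in the form $(a, b, c)$ with $ax + by + c = 0$, the sketch looks for sign changes of the line equation between consecutive vertices and reports the normalized arc-length parameter at each crossing. The function name and the toy curve are assumptions.

```python
import numpy as np

def epiline_curve_intersections(l_B, curve_B):
    """Return normalized arc-length parameters u_B where the line l_B = (a, b, c)
    crosses the polyline curve_B of shape (n, 2)."""
    l_B = np.asarray(l_B, dtype=float)
    curve_B = np.asarray(curve_B, dtype=float)
    seg_len = np.linalg.norm(np.diff(curve_B, axis=0), axis=1)
    arc = np.concatenate([[0.0], np.cumsum(seg_len)])
    u = arc / arc[-1]                            # normalized arc length per vertex
    d = curve_B @ l_B[:2] + l_B[2]               # signed line equation at each vertex
    hits = []
    for i in range(len(d) - 1):
        if d[i] == 0.0:                          # vertex lies exactly on the line
            hits.append(float(u[i]))
        elif d[i] * d[i + 1] < 0.0:              # sign change: crossing inside segment
            frac = d[i] / (d[i] - d[i + 1])
            hits.append(float(u[i] + frac * (u[i + 1] - u[i])))
    if d[-1] == 0.0:
        hits.append(float(u[-1]))
    return hits

# L-shaped toy curve of total length 4, intersected with the line x = 1.
curve_B = [[0.0, 0.0], [2.0, 0.0], [2.0, 2.0]]
hits = epiline_curve_intersections([1.0, 0.0, -1.0], curve_B)
misses = epiline_curve_intersections([1.0, 0.0, -5.0], curve_B)
print(hits, misses)
```

A looping guidewire can cross the same epiline more than once, so a function of this kind naturally returns a list; choosing among several candidates then requires extra information, such as the ordering of neighbouring matches.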

Due to errors in the projection matrices $P_A$ and $P_B$, there are instances where the epiline $l_B$ does not intersect with any part of the curve $\mathbf{C}_B$. To address this, we fit a monotonic function $f_A(u_A) \rightarrow u_B$ using a Piecewise Cubic Hermite Interpolating Polynomial (PCHIP), thus interpolating the missing intersections. The matching process is visualized in Fig. 4.
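Filling the missing matches with PCHIP can be sketched using SciPy's `PchipInterpolator`. The parameter values below are made up for illustration, with `NaN` marking samples whose epiline found no intersection; the interpolant is fitted on the valid pairs only and evaluated where matches are missing.

```python
import numpy as np
from scipy.interpolate import PchipInterpolator

# Hypothetical matches: u_A sampled uniformly, u_B missing (NaN) at three samples.
u_A = np.linspace(0.0, 1.0, 11)
u_B = np.array([0.0, 0.12, np.nan, 0.33, 0.41, np.nan, np.nan, 0.72, 0.80, 0.91, 1.0])

valid = ~np.isnan(u_B)
f_A = PchipInterpolator(u_A[valid], u_B[valid])   # shape-preserving cubic interpolant

# Keep the measured matches and fill only the gaps.
u_B_filled = np.where(valid, u_B, f_A(u_A))
print(u_B_filled)
```

PCHIP does not overshoot between data points, so when the valid matches are increasing the filled values stay increasing as well, which keeps the point ordering along the two curves consistent.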

Figure 4. Point Matching Process. Sampled points from image $I_A$ ($\mathbf{C}_A(u_A)$) and their corresponding epilines $l_A$ on image $I_B$ are matched with their counterparts $\mathbf{C}_B(u_B)$. The epilines for $\mathbf{C}_B(u_B)$ are then computed and displayed on image $I_A$.

Utility of Guide3D Dataset for the Research Community

Guide3D advances endovascular imaging by providing a bi-planar fluoroscopic dataset for segmentation and 3D reconstruction, serving as an open-source benchmark. It enables precise algorithm comparisons for segmentation and facilitates method development in 3D reconstruction through the use of bi-planar imagery. With video data, Guide3D supports video-based methods, leveraging temporal dimensions for dynamic analysis. This enriches the segmentation and reconstruction capabilities, while also aligning with the procedural nature of endovascular interventions. This versatility highlights Guide3D's pivotal role in advancing endovascular imaging.

Next

In the next part, we will explore Guidewire Shape Prediction methodology.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 1)

Endovascular surgical tool reconstruction represents an important factor in advancing endovascular tool navigation, which is an important step in endovascular surgery. However, the lack of publicly available datasets significantly restricts the development and validation of novel machine learning approaches. Moreover, due to the need for specialized equipment such as biplanar scanners, most of the previous research employs monoplanar fluoroscopic technologies, hence only capturing the data from a single view and significantly limiting the reconstruction accuracy.

To bridge this gap, we introduce Guide3D, a bi-planar X-ray dataset for 3D reconstruction. The dataset is a collection of high-resolution, manually annotated bi-planar fluoroscopic videos, captured in real-world settings. Validating our dataset within a simulated environment reflective of clinical settings confirms its applicability for real-world applications. Furthermore, we propose a new benchmark for guidewire shape prediction, serving as a strong baseline for future work. The proposal not only addresses an essential need by offering a platform for advancing segmentation and 3D reconstruction techniques but also aids the development of more accurate and efficient endovascular surgery interventions.

Introduction

Minimally invasive surgery has revolutionized endovascular interventions, offering less invasive options with shorter recovery times. The success of these procedures depends on the precise navigation and manipulation of instruments such as guidewires and catheters. Typically, 2D visualization methods are used for guidance, with monoplanar fluoroscopy being the most common due to its minimal disruption to surgical workflows and relatively affordable cost. However, despite their widespread use, conventional imaging techniques have significant limitations, with one of the primary challenges being the lack of depth perception. This issue complicates the accurate visualization of surgical instruments, increasing the risk of excessive contact with arterial walls, which can compromise patient safety and the effectiveness of the procedure.

In endovascular interventions, depth perception is largely achieved through multi-view imaging systems, such as biplanar scanners, which allow shape reconstruction by combining images from multiple angles and employing epipolar geometry-based reconstruction. However, two major challenges hinder the broader adoption and effectiveness of these systems: (i) the difficulty of accurately segmenting images for successful shape reconstruction, exacerbated by the scarcity of datasets needed to evaluate segmentation methods, and (ii) the limited availability of specialized biplanar scanners in clinical settings due to their high cost. These challenges underscore the critical need for comprehensive datasets to enhance segmentation algorithm accuracy and improve guidewire reconstruction techniques, facilitating the development of more versatile imaging technologies.

Figure 1. Guide3D dataset contains 8,746 manually annotated frames from two views for 3D reconstruction.

In this paper, we introduce Guide3D, a dataset designed to advance 3D reconstruction in endovascular navigation. Guide3D provides a standardized platform for the development and evaluation of algorithms. With a comprehensive dataset that includes manual annotations for segmentation and tools for effective 3D visualization, Guide3D is intended to drive innovation and improvement in endovascular intervention. Furthermore, the inclusion of video-based biplanar fluoroscopic data allows for the exploration of temporal dynamics, for example with optical flow networks. Guide3D seeks to bridge the gap between research innovations and clinical applications, addressing key challenges in endovascular procedures.

Related Works

Endovascular Datasets.

Datasets play a crucial role in advancing endovascular navigation by providing essential resources for the development, evaluation, and enhancement of algorithms. These datasets, derived from various imaging modalities such as mono X-ray, 3D ultrasound, and 3D MRI, encompass both real and synthetic images, facilitating diverse applications in the medical field.

Mono X-ray datasets, while prevalent, often fall short in providing the necessary detail required for accurate 3D reconstruction, which is critical for effective surgical navigation. The inherent limitations of 2D imaging techniques make it challenging to fully capture the complexity of anatomical structures during procedures. In contrast, 3D imaging modalities like 3D ultrasound and 3D MRI offer more comprehensive views, enabling better depth perception and improved visualization of surgical tools and surrounding tissues.

Despite the importance of these datasets, there remains a significant gap in the availability of comprehensive, publicly accessible datasets specifically designed for tool segmentation and 3D reconstruction. This scarcity hampers progress in developing robust algorithms capable of accurately interpreting complex medical images. The lack of diverse and high-quality datasets also limits the ability of researchers to train and validate their algorithms effectively, often leading to suboptimal performance in clinical scenarios.

Furthermore, creating high-quality datasets is not merely a technical challenge; it requires collaboration among various stakeholders, including clinicians, radiologists, and data scientists. Such collaboration is essential to ensure that the datasets reflect real-world clinical conditions and include diverse patient populations. Expanding the availability of well-annotated datasets is vital for fostering innovation and advancing the field of endovascular surgery.

Figure 2. Endovascular Dataset Explanation.

Catheter and Guidewire Segmentation.

The segmentation of endovascular tools, particularly guidewires and catheters, is an evolving field that heavily relies on the availability and quality of datasets. Previous studies have often used synthetic and semi-synthetic data to address the challenges posed by the limited availability of real-world datasets. Researchers have employed manually annotated datasets from 2D X-ray and 3D MRI modalities to train segmentation models. Additionally, the effectiveness of synthetic datasets has been demonstrated in improving model efficiency.

The advent of deep learning techniques, especially U-Net architectures, has significantly enhanced the accuracy of segmentation and tracking for these surgical instruments. This advancement has led to the development of fully automated segmentation frameworks that utilize extensively annotated data and incorporate unsupervised techniques, such as optical flow. However, the absence of a public, standardized dataset for method comparison continues to impede the advancement and assessment of scientific progress in this area.

Figure 3. Interventional Microcatheters.

3D Reconstruction.

Improving the accuracy of 3D reconstruction in endovascular procedures plays a crucial role in achieving better clinical outcomes by enhancing catheter navigation through advanced visualization and precise tracking. Advances in fluoroscopic imaging technology have led to more accurate positioning of devices. Various algorithms have been developed to facilitate this process, employing techniques such as elastic grid registration and epipolar geometry for 3D reconstruction from biplane angiography. Additionally, automatic catheter detection methods utilizing triangulation and graph-search algorithms have been applied in electrophysiology studies to improve reconstruction outcomes.

Research has demonstrated the importance of accurate 3D models for navigation within both complex and single-view vascular architectures, highlighting the value of biplanar data. However, the limited availability of comprehensive, publicly accessible datasets for the development and validation of algorithms in 3D reconstruction poses a significant challenge to technological progress and clinical application. This situation underscores the critical need for specialized datasets to promote ongoing innovation in the reconstruction of endovascular tools.

Figure 4. Guidewire Calibration.

Next

In the next part, we will dive deeper into how to construct a dataset for 3D shape reconstruction.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 4)

In the previous part, we introduced the training process and experimental setup. In this part, we validate the effectiveness and efficiency of the proposed method.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with 100 dancers.

Experimental Results

Quality Comparison

Table 1 presents a comparison among the baselines FACT, Transflower, EDGE, GDanceR, GCD, and our proposed manifold-based method. The results clearly demonstrate that our model outperforms the baselines across all evaluations on two datasets, AIOZ-GDANCE and AIST-M. We observe that recent diffusion-based dance generation models, such as EDGE or GCD, achieve competitive performance on both single-dance metrics (FID, MMC, GenDiv, and PFC) and group dance metrics (GMR, GMC, and TIF). However, due to limitations in their training procedures, they still struggle with generating multiple dancing motions when faced with a large number of dancers, as indicated by their lower performance compared to our method. This suggests that our approach successfully maintains the quality of dance motions as the number of dancers increases.

Table 1. Performance comparison.

Additionally, Figure 2 illustrates that our proposed method outperforms other state-of-the-art models like GDanceR and GCD in addressing issues such as monotonous, repetitive, sinking, and overlapping dance motions.

Figure 2. Visualization of a dancing sample between different methods. GDanceR displays monotonous, repetitive, or sinking dance motions. GCD exhibits more divergence in dance motions, yet dancers may intersect since their optimization does not address this issue explicitly. Blue boxes mark these issues. In contrast, our manifold-based solution ensures the divergence of dancing motions, while the phase motion path demonstrates its effectiveness in addressing floating and crossing issues in group dances.

Scalable Generation Analysis

Table 2 reports the performance of different group dance generation methods (GCD, GDanceR, and ours) as the number of generated dancers increases. GCD runs out of memory once the number of dancers reaches 10, and the same happens to GDanceR when the number of dancers increases to 100. Regardless of the number of dancers, our method consistently achieves stable and competitive results. This implies that our proposed method successfully addresses the scalability issue in group dance generation without compromising the overall performance of each individual dance motion.

Table 2. Performance of group dance generation methods when we increase the number of generated dancers. The experiments are done with common consumer GPUs with 4GB memory. (N/A means the model could not run due to insufficient memory.)

Figure 3 illustrates the memory consumed by each method to generate group dance motions. Notably, our method still achieves the highest performance while consuming far fewer resources to generate group dance motions (see Figure 4 for illustrations). This, again, indicates that our method successfully breaks the barrier on the number of generated dancers by using the manifold.

Figure 3. Memory usage vs. number of dancers in different dance generators.

Figure 4. Visualization of a dancing sample between different methods. GDanceR displays monotonous, repetitive, or sinking dance motions. GCD exhibits more divergence in dance motions, yet dancers may intersect since their optimization does not address this issue explicitly. Blue boxes mark these issues. In contrast, our manifold-based solution ensures the divergence of dancing motions, while the phase motion path demonstrates its effectiveness in addressing floating and crossing issues in group dances.

Ablation Study

Table 3 presents the performance improvements achieved through the integration of the consistency loss and the phase manifold. Additionally, we showcase the effectiveness of our proposed approach across three different backbones: Transformer, LSTM, and CNN. We use FID, GMR, and GMC as evaluation metrics. The results indicate that removing the consistency loss leads to an increase in GMR and a decrease in GMC, suggesting that the proposed objective significantly enhances the realism and correlation of group dance motions. Meanwhile, without the phase manifold, the model exhibits remarkably higher scores in both the FID and GMR metrics, suggesting that the phase manifold can effectively improve the distinction in dance motions while maintaining the realism of group dances, even when the number of dancers in a group is large. Among the three backbones, the chosen Transformer backbone achieves the best results.

Table 3. Module contribution and loss analysis.

User Study

User studies are vital for evaluating generative models, as user perception is pivotal for downstream applications. We therefore conducted two studies with around 70 participants each to assess the effectiveness of our approach in group choreography generation. The participants were diverse in background, had experience in music and dance, were aged between 20 and 40, and consisted of approximately 47% females and 53% males.

In the user study, we aim to assess the realism of sample outputs as the number of dancers grows. Specifically, participants assign scores ranging from 0 to 10 to evaluate the realism of each dance clip with 2 to 10 dancers. Figure 5 shows that, across all methods, realism decreases as the number of dancers increases. However, the drop in realism is smallest for our proposed method compared to GCD and GDanceR. The results indicate that our method performs well relative to the baselines when the number of dancers increases.

Figure 5. Realism between different methods when number of dancers is varied.

Discussion

While our approach leverages the VAE as a primary solution for generating a manifold, it is important to acknowledge certain inherent limitations associated with this choice. One notable challenge is the susceptibility to issues such as posterior collapse and unstable sampling within the VAE framework. These challenges can result in generated group dance motions that may not consistently meet performance expectations.

One specific manifestation of this limitation is the potential for false decoding when sampling points that lie too far from the learned distribution. This scenario can lead to unexpected rotations or disruptions in the physics of the generated content. The impact of this problem becomes evident in instances where the generated samples deviate significantly from the anticipated distribution, introducing inaccuracies and distortions.

To address these challenges, we recognize the need for ongoing efforts to mitigate the effects of posterior collapse and unstable sampling. While the problem is acknowledged, our approach incorporates measures to limit its impact. Future research directions could explore alternative generative models or additional techniques to enhance the robustness and reliability of the generated results in the face of these identified limitations.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 8)

Welcome to the concluding part of our series! In this final installment, we take a closer look at the search system that accompanies our proposal and conduct a comprehensive evaluation of the overall effectiveness and performance of our solution.

1. Searching and Indexing Method in the CFOR System

To adapt our retrieval system for large-scale datasets, we've developed indexes for the CFOR system to facilitate nonexhaustive similarity search using GPU acceleration. Leveraging the searching algorithm outlined by Johnson et al. (referenced as "billion-scale similarity search with GPUs"), we've implemented it within the CFOR system for retrieval tasks. In searching, the CFOR system enhances accuracy by reducing the search space through additional information such as regions, categories, and attributes. For indexing, the object ontology aids in creating multi-indexing files to minimize search time. Our focus is on similarity search in vector collections using the L2 distance in the k-selection algorithm.

In the realm of searching, we distinguish between exact search (exhaustive search) and compressed search (greedy nonexhaustive search).

  • Exact search: Almost all searching algorithms of this type compute the full pairwise distance between the query and each data point in the database, either sequentially or using the index file.
  • Compressed-Domain search: Almost all searching algorithms of this type compute the distance between the query and each data point in the database after applying space transformation, encoding, subspace splitting, or hashing. These methods can reduce searching time by using index files, but they trade off some searching accuracy.
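
The contrast between the two search types can be sketched in a few lines of pure Python. This is only an illustration, not the CFOR implementation (which builds on the GPU method of Johnson et al.). `exact_search` scans every vector with the L2 distance, while `ivf_search` is a hypothetical inverted-file variant that scans only the buckets of the nearest centroids. All function names, the centroids, and the toy data are assumptions made for this example.

```python
import heapq
import math


def l2(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def exact_search(query, database, k):
    """Exhaustive search: compare the query with every database vector."""
    scored = ((l2(query, vec), idx) for idx, vec in enumerate(database))
    return [idx for _, idx in heapq.nsmallest(k, scored)]


def build_ivf_index(database, centroids):
    """Assign each database vector to the bucket of its nearest centroid."""
    buckets = {c: [] for c in range(len(centroids))}
    for idx, vec in enumerate(database):
        nearest = min(range(len(centroids)), key=lambda c: l2(vec, centroids[c]))
        buckets[nearest].append(idx)
    return buckets


def ivf_search(query, database, centroids, buckets, k, nprobe=1):
    """Non-exhaustive search: scan only the nprobe closest buckets."""
    order = sorted(range(len(centroids)), key=lambda c: l2(query, centroids[c]))
    candidates = [idx for c in order[:nprobe] for idx in buckets[c]]
    scored = ((l2(query, database[idx]), idx) for idx in candidates)
    return [idx for _, idx in heapq.nsmallest(k, scored)]


# Toy 2-D "feature vectors"; real deep features would be much longer.
database = [(0.0, 0.0), (0.1, 0.2), (5.0, 5.0), (5.2, 4.9), (0.9, 0.0)]
centroids = [(0.0, 0.0), (5.0, 5.0)]  # hypothetical, e.g. obtained by k-means
buckets = build_ivf_index(database, centroids)

query = (0.2, 0.1)
print(exact_search(query, database, k=2))
print(ivf_search(query, database, centroids, buckets, k=2, nprobe=1))
```

With a small `nprobe` the non-exhaustive variant touches fewer vectors but can miss true neighbors that fall in an unvisited bucket, which is the accuracy trade-off mentioned above.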

2. Data

Our fashion retrieval system was built on a subset of approximately 300,000 images from DeepFashion. In the DeepFashion dataset, objects are captured from different viewpoints against complicated backgrounds. Each image is richly annotated, with different fine-grained labels depending on the details the current model is concerned with. The samples in Figures 1 and 2 show more details about the DeepFashion dataset.

Figure 1. Images from the DeepFashion dataset captured from different views against complicated backgrounds.

Figure 2. Images from the DeepFashion dataset annotated with different labels depending on the details the current model is concerned with.

In testing, we employ part of the benchmark data to fine-tune the trained models. We ensure that there are no fashion item overlaps between the fine-tuning and testing sets. The dataset is split into ∼220,000 training images, 40,000 validation images, and 40,000 testing images. However, in attribute learning, we limited the number of attribute labels used for testing and the number of training images for specific attributes to create an imbalanced attribute dataset on which to evaluate our proposed methods.

3. Results and Discussion

In the CFOR system, object ontology is useful in controlling training flow which impacts the performance of object category classification and attribute multitask classification. For object category classification, ontology controls the amount of training data through concepts. For attribute multitask classification, ontology manages local grouping which directly affects the performance of the proposed local imbalanced data solver on the large-scale dataset.

In this section, we will evaluate the effectiveness of different deep networks with the support of ontology on both category classification and attribute multitask classification in the CFOR system to pick out the best architecture for training the system. We will also compare our results with FashionNet.

Category Classification

We compare the performance of different deep architectures, including NASNet, ResNet-18, ResNet-101, FashionNet, NASNet with average pooling dropout (NASNet APD, proposed by us), and ResNet with average pooling dropout (ResNet APD, proposed by us). These experiments are evaluated by top-k accuracy (Figure 3). Our target is to find the best possible architecture to serve as the core network of the CFOR system. This step can be regarded as a preparation step before applying the CFOR system to fashion retrieval.

Figure 3. Accuracy plot for top-k accuracy in category classification.

ResNet-18 APD achieves 1.23% higher accuracy (at k = 1) than the original ResNet-18 architecture after removing nodes and applying average pooling. The corresponding gain is 0.93% for ResNet-101 APD over the original ResNet-101 and 0.02% for NASNet v3 APD over the original NASNet v3. The ResNet-101 APD architecture (the best architecture evaluated) outperforms FashionNet (the best-performing architecture for category classification on the DeepFashion dataset compared with others such as WTBI or DARN) by 4.6% at k = 3 and 2.58% at k = 5. Based on these experimental results, the ResNet-101 architecture provides better classification performance than the others (NASNet and ResNet-18). For this reason, we propose ResNet-101 as the core network architecture for training classification models.
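
For readers unfamiliar with the metric used above, here is a minimal sketch of top-k accuracy. The function name, scores, and labels are made up for illustration and are not taken from the CFOR experiments.

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


# Hypothetical class scores for 4 images over 3 categories.
scores = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.4, 0.5, 0.1],
    [0.2, 0.3, 0.5],
]
labels = [0, 1, 0, 0]

for k in (1, 2, 3):
    print(k, top_k_accuracy(scores, labels, k))
# prints: 1 0.25, then 2 0.75, then 3 1.0
```

Accuracy can only stay the same or grow as k increases, which is why the gaps between architectures are usually reported at several values of k.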

Attribute Learning

Attribute multitask learning is an important part of the CFOR system. In this section, we evaluate the performance of the proposed local imbalanced data solver with MCC in dealing with the imbalanced attribute data on the large-scale fashion dataset.

Precision is the proportion of relevant instances among the retrieved instances, so it depends on both the true positives and the false positives of each attribute. However, the numbers of true positives and false positives are biased by the imbalanced data problem, so precision is affected as well. In contrast, recall, which considers true positives but not false positives, better reflects performance on attributes with fewer data and is therefore used to evaluate the experiments.
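
The quantities discussed here can be made concrete with a small sketch. The formulas below are the standard definitions of precision, recall, and the Matthews correlation coefficient (MCC). How the CFOR system applies MCC internally is not shown in this post, so the toy labels and function names are assumptions for illustration only.

```python
import math


def confusion_counts(y_true, y_pred):
    """Count true/false positives and negatives for one binary attribute."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    return tp, fp, fn, tn


def precision(tp, fp):
    return tp / (tp + fp) if tp + fp else 0.0


def recall(tp, fn):
    return tp / (tp + fn) if tp + fn else 0.0


def mcc(tp, fp, fn, tn):
    """Matthews correlation coefficient; 0.0 when the denominator vanishes."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


# A rare attribute: only 2 positives among 10 samples (made-up data).
y_true = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
y_pred = [1, 0, 1, 0, 0, 0, 0, 0, 0, 0]

tp, fp, fn, tn = confusion_counts(y_true, y_pred)
print(precision(tp, fp), recall(tp, fn), mcc(tp, fp, fn, tn))
# prints 0.5 0.5 0.375
```

Unlike precision or recall alone, MCC uses all four confusion counts, which is one reason it is commonly used for imbalanced labels.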

Local MTL outperforms STL and MTL in 28/35 attributes with a 54.70% recall rate (higher than STL (17.06%) and MTL (28.70%)). While single-task learning is weak on attributes with fewer data, and multitask learning struggles with the serious imbalance problem and the fewer intergroup correlations in fashion data, local MTL can lower their negative influences as well as widen the positive effect of inner-group correlations on attribute learning. Thus, local MTL outperforms STL and MTL in 13/15 attributes with fewer samples (Figure 4).

Figure 4. Recall graph of 14 attributes in STL and local MTL.

Setting aside the chic, solid, and maxi attributes, for which MTL with and without MCC achieve equal accuracy, MTL with MCC had higher recall than MTL without MCC in 20/35 of the remaining attributes. The overall performance increases by about 3%. For attributes with fewer data, MTL with MCC had higher recall than MTL without MCC in 9/14 attributes. The overall performance for these attributes increases by 5.14% (see Figure 5 for more details).

Figure 5. Recall graph of 35 attributes using local multitask models with and without MCC.

Retrieval in CFOR System Results

In this experiment, we test the retrieval ability of the CFOR system using MAP, from 1 retrieval result per query (MAP@1) to 30 retrieval results per query (MAP@30). The similarity retrieval experiment checks whether the attributes extracted from the retrieved images match the ground-truth attributes of the query image. The retrieval method is based on deep features and 35 attributes. After experimenting on 35 attributes belonging to 5 groups, the starting MAP@5 is acceptable (around 0.531), which shows the effectiveness of the searching method. The MAP@30 is around 0.815, and the trend keeps rising, which shows the consistency and stability of the information prediction methods in the CFOR system. A simple visualization of the retrieval process in the CFOR system is shown in Figure 6.
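
As a rough illustration of MAP@k, the sketch below uses one common definition: average precision over the relevant items found in the top k, averaged across queries. The exact variant used in the CFOR evaluation is not specified in this post, so the normalization, function names, and relevance flags are all assumptions.

```python
def average_precision_at_k(relevance, k):
    """AP@k for one query; `relevance` is a ranked list of 0/1 flags."""
    hits = 0
    precision_sum = 0.0
    for rank, rel in enumerate(relevance[:k], start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank  # precision at this rank
    return precision_sum / hits if hits else 0.0


def mean_average_precision_at_k(all_relevance, k):
    """MAP@k: mean of AP@k over all queries."""
    total = sum(average_precision_at_k(r, k) for r in all_relevance)
    return total / len(all_relevance)


# Made-up relevance flags for two queries (1 = retrieved item matches).
queries = [
    [1, 0, 1, 0, 0],
    [0, 1, 1, 0, 0],
]

print(mean_average_precision_at_k(queries, 1))
print(mean_average_precision_at_k(queries, 3))
```

In the attribute-matching setting described above, a flag of 1 would mean that the retrieved image's predicted attributes agree with the query's ground-truth attributes.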

Figure 6. An example of the retrieval results of the CFOR system.

4. Conclusion

This work presents the coarse-to-fine object retrieval (CFOR) system, a learning framework for e-commerce online retrieval that is designed to deal with large-scale imbalanced datasets. The framework acts on both input and output, reconstructing datasets from the coarse-grained level to the fine-grained level, and is believed to be an effective method for improving learning performance in retrieval. For input reconstruction, the ontology-based framework is used for threading the training flow, local grouping in multitask attribute learning, and hierarchical storage and retrieval. For output optimization, we take advantage of MCC to minimize the effect of the imbalanced dataset on multitask attribute learning.

Through extensive experiments, we demonstrate the applicability of object ontology in improving training flow, the effectiveness of different deep networks (ResNet and NASNet) applied on important tasks in fine-grained retrieval, and the usefulness of local multitask attribute learning and an MCC-based imbalanced data solver in attribute multitask learning. The CFOR system is designed to have flexibility so that it can be optimized easily in the future.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 3)

In the previous part, we introduced our main proposal, Variational Phase Manifold Learning. In this part, we explore the training process and experimental setup used to validate the proposed method.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with 100 dancers.

Training

During training, we consider the following variational lower bound to mainly train our dance generation VAE model:

$$\log p_\theta(\mathbf{x}|\mathbf{a}) \geq \mathbb{E}_{q_\phi} \left[ \log p_\theta(\mathbf{x}|\mathbf{z},\mathbf{a}) \right] - D_{\text{KL}}(q_\phi(\mathbf{z}|\mathbf{x},\mathbf{a}) \Vert p_\theta(\mathbf{z}|\mathbf{a}))$$

In practice, we apply the conditional VAE loss, which is defined as the weighted sum $\mathcal{L}_{\text{cvae}} = \mathcal{L}_{\text{rec}} + \lambda_{\text{KL}}\mathcal{L}_{\text{KL}}$. In particular, the reconstruction term $\mathcal{L}_{\text{rec}}$ measures the motion reconstruction error given the decoder output (via a smooth-L1 loss). The KL divergence term $\mathcal{L}_{\text{KL}}$ measures the divergence $D_{\text{KL}}$ between the posterior and the prior distribution to enforce them to be close to each other.
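
To make the weighted sum concrete, here is a minimal pure-Python sketch of the conditional VAE loss for a single sample. It assumes diagonal Gaussian posterior and prior distributions parameterized by a mean and a log-variance, which is a common choice but is not spelled out above. The function names and the `beta` parameter of the smooth-L1 loss are illustrative, and the default weight 5e-4 is the value reported under Implementation Details.

```python
import math


def smooth_l1(pred, target, beta=1.0):
    """Mean smooth-L1 (Huber-like) error between two flat vectors."""
    total = 0.0
    for p, t in zip(pred, target):
        diff = abs(p - t)
        # Quadratic near zero, linear for large errors.
        total += 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
    return total / len(pred)


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) between two diagonal Gaussians (assumed parameterization)."""
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        kl += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return kl


def cvae_loss(pred, target, mu_q, logvar_q, mu_p, logvar_p, lambda_kl=5e-4):
    """L_cvae = L_rec + lambda_KL * L_KL for one sample."""
    rec = smooth_l1(pred, target)
    kl = kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p)
    return rec + lambda_kl * kl


# Toy values: a 2-D "motion" and a 1-D latent.
loss = cvae_loss(
    pred=[0.0, 2.0], target=[0.0, 0.0],
    mu_q=[1.0], logvar_q=[0.0], mu_p=[0.0], logvar_p=[0.0],
)
print(loss)
```

In a real training loop these terms would be computed on batched tensors with automatic differentiation; the scalar version here only shows how the two terms combine.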

The conditional VAE objective above is calculated for each dancer separately and cannot capture the correlation between dancers within a group. Therefore, it is essential to impose consistency among dancers and avoid strange effects such as unsynchronized dance. To this end, we propose a group consistency loss by enforcing the latent phase manifold to be similar for the same group, given the input music. Specifically, we first calculate the phase manifold features based on the frequency domain parameters as follows:

$$\mathbf{P}_{2i-1} = \mathbf{A}_i\sin(2\pi \cdot \mathbf{S}_i), \qquad \mathbf{P}_{2i} = \mathbf{A}_i\cos(2\pi\cdot \mathbf{S}_i)$$

where $\mathbf{P}\in\mathbb{R}^{2D}$ is the phase manifold vector that encodes the spatial-temporal information of the motion state. The phase feature may look similar to the positional encodings of transformers in the sense that both use multi-resolution sinusoidal functions. However, the phase feature is much richer in terms of representation capacity since it learns to embed the spatial (body joints) and temporal (positions in time) information of the motion curves, whereas the positional encodings only encode the position of words. Finally, our consistency objective is to constrain the phase manifolds of dancers within a group to be as close as possible to each other, which is formulated as:

$$\mathcal{L}_{\text{csc}} = D_{\text{KL}}(q_\phi(\mathbf{z}|\mathbf{x}^m,\mathbf{a}) \Vert q_\phi(\mathbf{z}|\mathbf{x}^n,\mathbf{a})) + \Vert \mathbf{P}^m - \mathbf{P}^n\Vert^2_2$$

where the first term encourages the network to map different dancers belonging to the same group ($\mathbf{x}^m$ and $\mathbf{x}^n$) into the same set of distributional phase parameters while the second term penalizes the discrepancy in their corresponding phase manifolds. In general, this loss is applied to ensure every dancer is embedded into a single unified manifold that can effectively represent their corresponding group. To summarize, our total training loss is defined as the combination of the VAE loss and the consistency loss $\mathcal{L} = \mathcal{L}_{\text{cvae}} + \lambda_{\text{csc}}\mathcal{L}_{\text{csc}}$.
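
The phase features and the consistency loss can be sketched as follows for a pair of dancers. As before, this assumes diagonal Gaussian posteriors given as (mean, log-variance) pairs. The argument names `a_params` and `s_params` stand for the frequency-domain parameters A and S, and all numeric values are made up for illustration.

```python
import math


def phase_manifold(a_params, s_params):
    """Build the 2D-dimensional phase vector from D pairs (A_i, S_i)."""
    features = []
    for a, s in zip(a_params, s_params):
        features.append(a * math.sin(2.0 * math.pi * s))  # P_{2i-1}
        features.append(a * math.cos(2.0 * math.pi * s))  # P_{2i}
    return features


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) between two diagonal Gaussians (assumed parameterization)."""
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        kl += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return kl


def consistency_loss(posterior_m, posterior_n, phase_m, phase_n):
    """L_csc for dancers m and n: posterior KL plus squared phase distance."""
    kl = kl_diag_gaussians(posterior_m[0], posterior_m[1],
                           posterior_n[0], posterior_n[1])
    phase_gap = sum((pm - pn) ** 2 for pm, pn in zip(phase_m, phase_n))
    return kl + phase_gap


# Two dancers of the same group with slightly different parameters.
phase_m = phase_manifold([1.0, 0.5], [0.00, 0.25])
phase_n = phase_manifold([1.0, 0.5], [0.05, 0.25])
post_m = ([0.0, 0.0], [0.0, 0.0])  # (mean, log-variance)
post_n = ([0.1, 0.0], [0.0, 0.0])

print(consistency_loss(post_m, post_n, phase_m, phase_n))
```

The loss is zero only when both the posteriors and the phase vectors of the two dancers coincide, which matches the stated goal of embedding every dancer of a group into one shared manifold.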

For testing, we can efficiently generate motions for an unlimited number of dancers by indefinitely drawing samples from the learned continuous group-consistent phase manifold. It is noteworthy that for inference, we only need to process the prior network once to obtain the latent distribution. To generate a new motion, we can sample from this latent (Gaussian) distribution and use the decoder to decode it back to the motion space. This approach is much more efficient and has significantly higher scalability than previous approaches, which are limited by the number of dancers processed simultaneously by the entire network.
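
The inference procedure can be summarized with a small sketch. Here `prior_network` and `decoder` are hypothetical stand-ins with made-up behavior, not the trained Transformer modules. The only point illustrated is the control flow: the prior runs once per music clip, and each additional dancer costs one latent sample and one decoder pass.

```python
import math
import random


def prior_network(music):
    """Hypothetical stand-in: maps music features to (mean, log-variance)."""
    mean = [sum(music) / len(music)] * 4
    logvar = [0.0] * 4
    return mean, logvar


def decoder(z, music):
    """Hypothetical stand-in: one pose vector per music frame."""
    return [[zi + a for zi in z] for a in music]


def sample_latent(mean, logvar, rng):
    """Reparameterized draw z = mean + sigma * eps from a diagonal Gaussian."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mean, logvar)]


def generate_group(music, num_dancers, seed=0):
    rng = random.Random(seed)
    mean, logvar = prior_network(music)  # processed once, shared by all dancers
    return [decoder(sample_latent(mean, logvar, rng), music)
            for _ in range(num_dancers)]


group = generate_group([0.1, 0.5, 0.9], num_dancers=100)
print(len(group), len(group[0]), len(group[0][0]))  # dancers, frames, pose size
```

Because each dancer is decoded independently from its own sample, the motions could also be generated one at a time, so the work per dancer does not depend on the size of the group.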

Figure 2. An example output.

Experiments

Implementation Details

Our model was trained on 4 NVIDIA V100 GPUs using the Adam optimizer with a fixed learning rate of $10^{-4}$ and a mini-batch size of 32 per GPU. For the training losses, the weights are empirically set to $\lambda_{\text{KL}} = 5\times 10^{-4}$ and $\lambda_{\text{csc}} = 10^{-4}$, respectively. The Transformer encoder, decoder, and prior network each consist of 5 layers with 512-dimensional hidden units. Meanwhile, the number of latent phase channels is set to 256. To further capture the periodic nature of the phase feature, we also use Siren activation following the initialization scheme. This can effectively model the periodicity inherent in the motion data, and thus can benefit motion synthesis.

Experimental Settings

Dataset. In our experiments, we utilize the AIOZ-GDance and AIST-M datasets. AIOZ-GDance is the largest music-driven dataset focusing on group dance, encompassing paired music and 3D group motions extracted from in-the-wild videos through a semi-automatic process. This dataset spans 7 dance styles and 16 music genres. For consistency, we adhere to the training and testing split during our experiments.

Evaluation Protocol. We employ several metrics to assess the quality of individual dance motions, including Frechet Inception Distance (FID), Motion-Music Consistency (MMC), and Generation Diversity (GenDiv), along with the Physical Foot Contact score (PFC). Specifically, the FID score gauges the realism of individual dance movements relative to the ground-truth dance. MMC assesses the matching similarity between motion and music beats, reflecting how well generated dances synchronize with the music's rhythm. GenDiv is computed as the average pairwise distance of kinetic features among motions. PFC evaluates the physical plausibility of foot movements by determining the agreement between the acceleration of the character's center of mass and the foot's velocity.
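
Of these metrics, GenDiv is the simplest to write down. The sketch below computes an average pairwise Euclidean distance over a set of feature vectors. The extraction of kinetic features from motions is not shown, so the input vectors, the choice of Euclidean distance, and the function name are assumptions for illustration.

```python
import math
from itertools import combinations


def average_pairwise_distance(features):
    """Mean Euclidean distance over all unordered pairs of feature vectors."""
    pairs = list(combinations(features, 2))
    if not pairs:
        return 0.0
    return sum(math.dist(u, v) for u, v in pairs) / len(pairs)


# Made-up 2-D "kinetic features" for three generated motions.
features = [(0.0, 0.0), (3.0, 4.0), (6.0, 8.0)]
print(average_pairwise_distance(features))  # (5 + 10 + 5) / 3
```

A higher value means the generated motions are more spread out in feature space, that is, more diverse.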

In assessing the quality of group dance, we adopt three metrics: Group Motion Realism (GMR), Group Motion Correlation (GMC), and Trajectory Intersection Frequency (TIF). Broadly, GMR gauges the realism of generated group motions in comparison to ground-truth data, employing Frechet Inception Distance on extracted group motion features. GMC evaluates the synchronization among dancers within the generated group by computing their cross-correlation. TIF quantifies the frequency of collisions among the generated dancers during their dance movements.

Baselines. Our method is subjected to comparison with various recent techniques in music-driven dance generation, namely FACT, Transflower, and EDGE. These approaches are adapted for benchmarking within the context of group dance generation, considering that their original designs were tailored for single-dance scenarios. Additionally, our evaluation includes a comparison with GDanceR, GCD, and DanY. All of the mentioned works are specifically designed for the generation of group choreography.

Figure 3. Generated Group Dance from GCD baselines

Next

In the next part, we will explore the effectiveness of the proposal through quantitative and qualitative results.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 7)

To provide fine-grained information to the CFOR system, attribute learning is one of the most important tasks, and it should be optimized in both processing time and the ability to deal with large-scale imbalanced datasets.

1. Framework

Local multitask learning is applied to attribute learning. The proposed framework, depicted in Figure 1 and Figure 2 and comprising online and offline phases, consists of three main components. The initial component introduces a local multitask transfer learning model with a loss function designed to leverage inner-group correlations among attributes. The second component presents an imbalanced data resolver based on MCC (Matthews Correlation Coefficient) without any adjustments to the pretrained model or loss function. The third component discusses prior knowledge used for local attribute grouping to facilitate local multitask learning.

Figure 1. Local MTL with an imbalanced data problem solver framework (Offline phase).

The input and output of the learning framework will be images and their attribute vectors, respectively. However, with the local grouping role, the attribute vector’s size will be based on the number of attributes in each group. The dataset should be merged or split based on the local grouping role.

To evaluate the effectiveness of the proposed framework, we apply it in the fashion field and split the dataset into five local groups: fabric, shape, part, style, and texture. Because fashion has fewer intergroup correlations, the shared block should be designed to optimize the effectiveness of inner-group correlations to improve the overall performance. However, for crowd attributes (such as activities, locations, and participants), intergroup correlations should be taken into account to improve performance. Thus, the shared block should be modified to adapt to the context.

Figure 2. Local MTL with an imbalanced data problem solver framework (Online phase).

2. Deep Multitask Learning

Our aim is to estimate a number of fashion attributes via a joint estimation model. However, when the attributes are dynamic, MTL, which builds a single joint estimation model, becomes hard to train and use as the number of attributes increases. The local grouping method helps address this situation.

Framework in Detail

In the experiments, the proposed framework processes the query image and generates a confidence score vector comprising 7 attribute scores per group across 5 groups, which is then thresholded to obtain binary outputs. The architecture is outlined in detail below.

Figure 1 illustrates the overall structure of the proposed method. For each group, a training set is assumed, consisting of $N$ fashion images, each with $M$ attributes. The dataset is represented as $D = \{(X_i, Y_i)\}$, where $X_i$ is the image and $Y_i$ is the label encoded as a one-hot vector. Inspired by prior research, we employ an end-to-end DNN architecture as a shared block to learn joint representations for all tasks. The loss function employed is binary cross-entropy, and the activation function at the output layer is sigmoid, chosen for its simplicity and flexibility in modifying the DNN architecture.
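
The output stage described here (sigmoid activation, binary cross-entropy, thresholding) can be sketched for one attribute group as follows. The logits, targets, and the fixed threshold of 0.5 are illustrative assumptions. The post does not state the threshold actually used, and the MCC-based solver may select it differently.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def binary_cross_entropy(probs, targets, eps=1e-7):
    """Mean binary cross-entropy over the attributes of one group."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)


def predict_attributes(logits, threshold=0.5):
    """Threshold the confidence scores to obtain binary attribute outputs."""
    return [1 if sigmoid(z) >= threshold else 0 for z in logits]


# Made-up raw scores for the 7 attributes of one group.
logits = [2.0, -1.0, 0.0, -3.0, 1.5, -0.5, 0.2]
targets = [1, 0, 0, 0, 1, 0, 1]

probs = [sigmoid(z) for z in logits]
print(predict_attributes(logits))
print(binary_cross_entropy(probs, targets))
```

Each attribute gets its own independent sigmoid, so several attributes of a group can be active for the same image.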

Network Architecture

NASNet automatically generates network architectures, constructing an optimal model by initially creating architectures on a smaller dataset and then scaling them up to a larger one. Through experiments, the search for the best cells is conducted on the CIFAR-10 dataset, which are subsequently applied to the ImageNet dataset by stacking multiple copies of them, each with their own parameters. The resulting model demonstrates a 1.2% improvement in top-1 accuracy compared to the best human-designed architectures. NASNet proves its effectiveness over previous architectures and offers a transfer learning model trained on a diverse ImageNet dataset. Leveraging the pretrained NASNet model on ImageNet, transfer learning is applied to the DeepFashion dataset to expedite convergence and enhance performance. Additionally, a dropout layer is incorporated with NASNet to mitigate overfitting. While utilizing the NASNet model generation algorithm to tailor a model for the DeepFashion dataset is a promising approach, the time and hardware resources required for NASNet's model generation and training from scratch are significant. Due to hardware limitations, only transfer learning is employed.

Figure 3. Best normal cells and reduction cells identified with CIFAR-10 and ImageNet architecture (right) are built from the best convolutional cells. Zoph et al. built two types of cells because they wanted to create architectures for images of any size. While normal cells return a feature map which has the same dimension, reduction cells return a feature map with height and width reduced by a factor of two.

We will run experiments on NASNet architectures to find out which one is suitable for each specific task in our CFOR system. In our fashion retrieval experiments, the category classifier and the region classifier are trained with transfer learning in a single-task setting, while fashion attribute recognition uses local multitask learning. In addition, to adapt to large-scale datasets and reduce overfitting, we recommend replacing the final fully connected layer with a global average pooling layer together with dropout.
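To make the recommended head concrete, here is a minimal sketch of global average pooling followed by dropout on a tiny feature map. It uses plain Python lists rather than a deep learning framework, applies dropout after pooling as one possible ordering, and all names and values are hypothetical.

```python
import random

def global_average_pool(feature_map):
    """Average each channel of a (C, H, W) nested-list feature map to one value."""
    pooled = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled

def dropout(vector, rate, rng, training=True):
    """Inverted dropout: zero each unit with probability `rate`, rescale the rest."""
    if not training or rate == 0.0:
        return list(vector)
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in vector]

# Hypothetical feature map with 2 channels of size 2x2.
features = [[[1.0, 3.0], [5.0, 7.0]], [[0.0, 0.0], [2.0, 2.0]]]
pooled = global_average_pool(features)  # one value per channel
dropped = dropout(pooled, rate=0.5, rng=random.Random(0))
```

Global average pooling has no trainable weights, so it removes the large parameter block of a fully connected layer, which is one common reason it helps against overfitting.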

Next

In the next post, we will explore the searching and indexing method in the CFOR system, as well as the effectiveness of the whole system.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 2)

In the previous part, we introduced the scalability problem in group dance generation and explained what a manifold is. In this part, we introduce our main proposal, Variational Phase Manifold Learning.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with $100$ dancers.

Task Definition

Given an input music sequence $\mathbf{a} = \{a_1, a_2, ..., a_T\}$, where $t = \{1,..., T\}$ indicates the index of the music frames, our goal is to generate the group motion sequences of $N$ arbitrary dancers: $\mathbf{x} = \{x^1_1,..., x^1_T; ...;x^N_1,...,x^N_T\}$, where $x^n_t$ is the pose of the $n$-th dancer at frame $t$. We use the 6D continuous rotation for every joint, along with 3D joint positions and velocities. Additionally, the corresponding 3D root translation vectors are concatenated into the pose representations to capture the motion trajectory. Previous group dance methods, which generate the whole group at once, cannot deal with an increasing number of dancers and can only create group sequences up to a pre-defined number of dancers, due to the vast complexity of their architectures. In contrast, we aim to generate group dance with an unlimited number of dancers.

Figure 2. An example output.

Phase-conditioned Dance VAE

Our goal is to learn a continuous manifold such that the motion can be generated by sampling from this learned manifold. We assume that although different dancers within the same group may present visually distinctive movements, the properties of their motions, such as timing, periodicity, or temporal alignment, are intrinsically similar. We aim to learn a generative phase representation for each group of dancers in order to synthesize their motion indefinitely. Our generative model is built upon the conditional Variational Autoencoder architecture, thanks to its diverse generation capability and fast sampling speed. However, instead of directly encoding the data into a Gaussian latent distribution as in common VAE approaches, we model the latent variational distribution by the phase parameters extracted from the latent motion curve, which we call the variational phase manifold. The latent phase manifold is well-structured and describes key characteristics of motion (such as its timing, local periodicity, and transition), which benefits learning motion features.

The overview of our Phase-conditioned Dance VAE is illustrated in Figure 3. Specifically, the model contains three main networks: an encoder $\mathcal{E}$ to capture the approximate posterior distribution conditioned on both motion and music $q_\phi(\mathbf{z}|\mathbf{x},\mathbf{a})$, a prior network $\mathcal{P}$ to learn the conditional prior given only the music $p_\theta(\mathbf{z}|\mathbf{a})$, and a decoder $\mathcal{D}$ to learn to reconstruct the data from the latent distribution $p_\theta(\mathbf{x}|\mathbf{z},\mathbf{a})$. The new motion is generated by sampling the frequency-domain parameters predicted by the prior network, which are then passed through the decoder network to reconstruct the motion in the original data space. Furthermore, we adopt a Transformer-based architecture in each network to effectively capture long-range dependencies and the holistic context of the dance sequence.

Figure 3. Overview of our Phase-conditioned Dance VAE (PDVAE) for scalable group dance generation. It consists of an Encoder, a Prior, and a Decoder network. During training, we encode the corresponding motion and music inputs into a latent phase manifold, which is variational and parameterized by the frequency domain parameters of periodic functions. The latent phases can be sampled from the manifold and then decoded back to the original data space to obtain new motions. The consistency loss $\mathcal{L}_{\text{csc}}$ is further imposed to constrain the manifold to be consistently unified for dancers that belong to the same group. At the inference stage, only the Prior and the Decoder are used to synthesize group dances efficiently.

Encoder

The encoder $\mathcal{E}$ is expected to take both the motion and music feature sequences as input, and produce a distribution over possible latent variables capturing the cross-modal relationship between them. To transform the joint input space into a learned phase manifold, we adopt the Transformer decoder architecture, where the cross-attention mechanism is utilized to learn the relationship between the motion and the music. Accordingly, the output of the encoder is a batch of latent curves (i.e., the activation sequences per channel) that can capture different spatial and temporal aspects of the motion sequence. However, instead of training the model to directly reconstruct the input motion from the extracted latent curves, we further enforce each channel of the latent space to have a periodic functional form (i.e., sinusoidal). This enables us to effectively learn a compact parameterization for each latent channel from a small set of parameters in the frequency domain.

Generative Variational Phase Manifold

Here we focus on learning the periodicity and non-linear temporal alignment of the motion in the latent space. In particular, given the output latent curves from the encoder $\mathbf{L} = \mathcal{E}(\mathbf{x},\mathbf{a}) \in \mathbb{R}^{D \times T}$, where $D$ is the number of desired phase channels to be extracted from the motion, we parameterize each latent curve in $\mathbf{L}$ using a sinusoidal function with amplitude ($\mathbf{A}$), frequency ($\mathbf{F}$), offset ($\mathbf{B}$) and phase shift ($\mathbf{S}$) parameters. To allow for variational phase manifold learning, we opt to predict two sets of parameters $\mathbf{\mu}_{\mathcal{E}} =\{\mathbf{\mu}^A; \mathbf{\mu}^F; \mathbf{\mu}^B; \mathbf{\mu}^S \}$ and $\mathbf{\sigma}_{\mathcal{E}} =\{\mathbf{\sigma}^A; \mathbf{\sigma}^F; \mathbf{\sigma}^B; \mathbf{\sigma}^S \}$, which correspond to the mean and variance of a $4D$-dimensional Gaussian distribution over $\mathbb{R}^{4D}$:

$$q_\phi(\mathbf{z}|\mathbf{x},\mathbf{a}) = \mathcal{N}(\mathbf{z};\mathbf{\mu}_{\mathcal{E}}, \mathbf{\sigma}_{\mathcal{E}})$$

To do so, we first apply a differentiable Fast Fourier Transform (FFT) to each channel of the latent curve $\mathbf{L}$ and create the zero-indexed matrix of Fourier coefficients as $\mathbf{c}=FFT(\mathbf{L})$ with $\mathbf{c} \in \mathbb{C}^{D \times (K+1)}$, $K =\lfloor \frac{T}{2}\rfloor$. Correspondingly, we compute the per-channel power spectrum $\mathbf{p} \in \mathbb{R}^{D \times (K+1)}$ as $\mathbf{p}_{i,j} = \frac{2}{T}|\mathbf{c}_{i,j}|^2$, where $i$ is the channel index and $j$ is the index of the frequency band. The distributional mean parameters of the periodic sinusoidal function are then calculated as follows:

$$\mathbf{\mu}^A_i = \sqrt{\frac{2}{T}\sum_{j=1}^K \mathbf{p}_{i,j}}, \quad \mathbf{\mu}^F_i = \frac{\sum_{j=1}^K \mathbf{f}_j \cdot \mathbf{p}_{i,j}}{ \sum_{j=1}^K \mathbf{p}_{i,j}}, \quad \mathbf{\mu}^B_i = \frac{\mathbf{c}_{i,0}}{T},$$

where $\mathbf{f} = (0, \frac{1}{T},\dots,\frac{K}{T})$ is the vector of frequencies. At the same time, the phase shift $\mathbf{S}$ is predicted using a fully connected (FC) layer followed by a two-argument $\arctan$ activation:

$$(s_y, s_x) = \text{FC}(\mathbf{L}_i), \quad \mathbf{\mu}^S_i = \arctan(s_y,s_x).$$

To predict the distributional variance of the phase amplitude and phase shift parameters $\{\mathbf{\sigma}^A, \mathbf{\sigma}^S\}$, we additionally apply a separate two-layer MLP network over each channel of the latent curves. The variational latent phase parameters are sampled using the reparameterization trick, i.e., $\mathbf{A}\sim\mathcal{N}(\mathbf{\mu}^A,\mathbf{\sigma}^A)$ and $\mathbf{S}\sim\mathcal{N}(\mathbf{\mu}^S,\mathbf{\sigma}^S)$. In our experiments, we find that sampling the phase frequency $\mathbf{F}$ and offset $\mathbf{B}$ often produces unstable and incoherent group movements. This might be because the frequencies of the dancers within the same group are likely to be associated with the rhythmic pattern of the musical beats, while the offsets capture their alignment, so both should be consistent across dancers. Therefore, we treat those parameters as deterministic by constraining their variance to zero.
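The sampling scheme in this paragraph can be sketched as follows. This is only an illustration under stated assumptions: the dictionary layout and the function name are hypothetical, and `sigma` is treated as a standard-deviation scale for the purpose of the reparameterized draw.

```python
import random

def sample_phase_parameters(mu, sigma, rng):
    """Reparameterized sample: value = mu + sigma * eps with eps ~ N(0, 1).

    `mu` maps "A", "F", "B", "S" to one value per phase channel.
    `sigma` only needs entries for "A" and "S"; the frequency F and the
    offset B are kept deterministic, i.e. their variance is fixed to zero.
    """
    sample = {}
    for name in ("A", "F", "B", "S"):
        if name in ("F", "B"):
            sample[name] = list(mu[name])  # deterministic parameters
        else:
            sample[name] = [m + s * rng.gauss(0.0, 1.0)
                            for m, s in zip(mu[name], sigma[name])]
    return sample

# Hypothetical means and scales for D = 2 phase channels.
mu = {"A": [1.0, 0.5], "F": [0.1, 0.2], "B": [0.0, 0.3], "S": [0.0, 0.25]}
sigma = {"A": [0.1, 0.1], "S": [0.05, 0.05]}
z = sample_phase_parameters(mu, sigma, random.Random(0))
```

Drawing several samples for the same music would then vary only the amplitude and phase shift of each dancer, while frequency and offset stay shared across the group.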

Finally, the sampled set of phase parameters $\mathbf{z} = \{\mathbf{A};\mathbf{F};\mathbf{B};\mathbf{S}\}$ is used to reconstruct a parametric latent space consisting of multiple periodic curves that represent each intrinsic property of the motion:

$$\hat{\mathbf{L}} = \mathbf{A} \cdot \sin (2\pi \cdot (\mathbf{F}\cdot\mathcal{T} - \mathbf{S})) + \mathbf{B}$$

where $\mathcal{T}$ is a known time window series obtained by evenly spacing the timesteps from $0$ to $T$. Intuitively, this curve construction procedure can be viewed as a "quantization" layer that forces the network to represent the motion features in the frequency domain, which is useful for representing different aspects of human motion such as timing and periodicity. In the last step, a decoder is utilized to reconstruct the original motion signals from the set of parametric latent curves.

Figure 4. Manifold construction.
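The extraction and reconstruction steps above can be sketched for a single latent channel as follows. This is a minimal illustration, not the authors' implementation: a naive DFT stands in for the differentiable FFT, the power spectrum is assumed to be normalized by the number of frames $T$, the time series $\mathcal{T}$ is assumed to be the frame indices $0, \dots, T-1$, and the values `s_y` and `s_x` stand in for the output of the FC layer.

```python
import cmath
import math

def dft_half(signal):
    """Naive DFT returning coefficients c_0..c_K with K = floor(T / 2)."""
    T = len(signal)
    return [sum(x * cmath.exp(-2j * math.pi * j * t / T)
                for t, x in enumerate(signal))
            for j in range(T // 2 + 1)]

def phase_means(curve, s_y, s_x):
    """Amplitude, frequency, offset and phase shift of one latent channel."""
    T = len(curve)
    K = T // 2
    c = dft_half(curve)
    p = [(2.0 / T) * abs(cj) ** 2 for cj in c]   # power spectrum
    f = [j / T for j in range(K + 1)]            # frequency vector
    band = sum(p[1:])                            # total power without the DC bin
    amplitude = math.sqrt((2.0 / T) * band)
    frequency = sum(f[j] * p[j] for j in range(1, K + 1)) / max(band, 1e-12)
    offset = c[0].real / T
    shift = math.atan2(s_y, s_x)                 # two-argument arctangent
    return amplitude, frequency, offset, shift

def reconstruct(amplitude, frequency, offset, shift, T):
    """Rebuild A * sin(2*pi*(F*t - S)) + B for t = 0..T-1."""
    return [amplitude * math.sin(2.0 * math.pi * (frequency * t - shift)) + offset
            for t in range(T)]

# Hypothetical channel: 4 cycles over T = 64 frames, amplitude 1.5, offset 0.3.
T = 64
curve = [1.5 * math.sin(2.0 * math.pi * 4 * t / T) + 0.3 for t in range(T)]
A, F, B, S = phase_means(curve, s_y=0.0, s_x=1.0)
rebuilt = reconstruct(A, F, B, S, T)
```

For a pure sinusoid like this one, the four recovered parameters describe the curve completely, which is the sense in which the parameterization is compact.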

Decoder

To decode the latent space into the original motion space, previous works have to use a sinusoidal positional encoding sequence with duration $T$ as the proxy input to the sequence decoder. This is because their latent space is formed only by single latent vectors following a Gaussian distribution, which cannot span the time dimension. However, we observe that this usually results in unstable and inconsistent movements, as the proxy sequence is generic and carries little meaningful information for the decoder. Our method does not suffer from this problem, as our latent space is built on multiple curves that can represent the motion information through time, thanks to the phase parameters. Accordingly, our decoder $\mathcal{D}$ is based on the Transformer decoder architecture and takes the constructed parametric latent curves, together with the music features, as inputs to reconstruct the corresponding dance motions. Here, we also utilize the cross-attention mechanism, where the sequence of music features serves as key and value and the sampled latent curves serve as the query. The output of the decoder is a sequence of $T$ vectors in $\mathbb{R}^D$, which is then projected back to the original motion dimensions through a linear layer to obtain the reconstructed outputs $\hat{\mathbf{x}}=p_\theta(\mathbf{x}|\mathbf{z},\mathbf{a})$. We additionally employ a global trajectory predictor to predict the global translation of the root joint based on the generated local motions, in order to avoid intersection problems between dancers.
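The role assignment in the cross-attention step (latent curves as query, music features as key and value) can be illustrated with a small sketch. It shows single-head scaled dot-product attention without the learned projections, multiple heads, or feed-forward layers of a real Transformer decoder, and the toy inputs are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(queries, keys, values):
    """Single-head scaled dot-product attention without learned projections."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
                  for k in keys]
        weights = softmax(scores)
        outputs.append([sum(w * v[c] for w, v in zip(weights, values))
                        for c in range(len(values[0]))])
    return outputs

# Hypothetical toy inputs: 3 time steps of latent curves, 2 music frames.
latent_steps = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # queries
music_feats = [[1.0, 0.0], [0.0, 1.0]]               # keys and values
attended = cross_attention(latent_steps, music_feats, music_feats)
```

Each output row is a mixture of music features weighted by how well they match the latent curve at that time step, so the output sequence keeps the temporal length of the query.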

Prior Network

Since the ground-truth motion is generally inaccessible at test time (i.e., we only have access to the music), we also need to learn a prior $\mathcal{P}$ to match the posterior distribution of motion from which the latent phase can be sampled. Specifically, we follow the same manifold procedure to predict the Gaussian distribution conditioned on the music sequence $\mathbf{a}$, which is then used for sampling the latent phases:

$$p_\theta(\mathbf{z}|\mathbf{a}) = \mathcal{N}(\mathbf{z};\mathbf{\mu}_{\mathcal{P}}, \mathbf{\sigma}_{\mathcal{P}})$$

where a Transformer encoder is used to encode the input conditioning music sequence and predict the corresponding $\mathbf{\mu}_{\mathcal{P}}$ and $\mathbf{\sigma}_{\mathcal{P}}$. We implement the prior network similarly to the encoder network; however, we use the self-attention mechanism to capture the global music context. Learning the conditional prior is crucial for the conditional VAE to generalize to diverse types of music and motion. Intuitively speaking, each latent variable $\mathbf{z}$ is expected to represent possible dance motions $\mathbf{x}$ conforming to the music context $\mathbf{a}$. Therefore, the prior should be able to encode different latent distributions given different music inputs.

Next

In the next part, we will explore the training procedure and experimental setups to validate the effectiveness of the proposed method.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 6)

In this section, we discuss ontology, the fashion ontology, and related concepts, and present the contributions of the object ontology to the CFOR system.

1. Ontology Definition for CFOR System

As described by Guarino, ontology is a "formal, explicit specification of a shared conceptualization." Typically, ontologies consist of concepts and their hierarchical structure, aiding in organizing information within a domain. A complete ontology typically includes concepts, relations, and axioms. Additionally, ontologies offer several key advantages:

  • Describing domain knowledge through a semantic hierarchical tree, with concepts represented as nodes identified by words or phrases.
  • Bridging the semantic gap in various tasks, including those in computer vision and other disciplines.
  • Enhancing software engineering practices by improving flexibility, reliability, specification, and reusability.
  • Supporting multitask problem-solving capabilities.

Any proposed ontology should satisfy two fundamental criteria:

  • Wide recognition within the community.
  • Feasibility for formalization using mathematical expressions, enabling digitization.

In our approach, we employ ontological engineering to facilitate communication and information sharing across different levels of data abstraction involved in image fashion retrieval, detection, and information tagging.

The object ontology comprises two primary levels: coarse-grained and fine-grained.

  • At the coarse level, the object ontology includes regions, categories, or high-level conceptual entities, which leverage global features extracted by deep networks for similarity retrieval. However, these deep features are treated as black boxes, lacking explicit semantic information to aid users in their search process.
  • At the fine-grained level, the object ontology encompasses attributes that provide detailed descriptions of objects.

In our experiment, we focus on describing the object "Fashion." The fashion ontology is constructed using prior knowledge and information from the DeepFashion dataset, along with ontology definitions introduced by Guarino. See Figure 1 for an illustration of the fashion ontology.

Figure 1. Fashion ontology in general and a version of ontology for clothes.

The fashion ontology comprises three primary semantic levels:

  1. Regions: areas such as Top, Bottom, and Body.
  2. Categories: specific objects associated with each region, such as dresses or robes for the Body region.
  3. Attributes: detailed visual concepts such as denim or fur.

To streamline the discussion, our investigation focuses on the object fashion across three regions (Top, Body, and Bottom), select categories within these regions, and their respective attributes.

Within the CFOR system, a query image is first processed at the coarse level of the object ontology to identify the region and category of the corresponding object. The object then proceeds to the fine-grained level of the ontology, where its attributes are determined. Once all necessary information is obtained, the object undergoes indexing and similarity distance computation to identify similar images in the database, ranked by a cumulative score derived from the similarity scores of global features and attribute learning between the query image and target database images. For a detailed illustration, refer to Figure 2.
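The coarse-to-fine flow described above can be sketched as a small ranking routine. Everything in it is an assumption made for illustration: the post does not specify the similarity functions or how the cumulative score is formed, so this sketch uses cosine similarity for global features, Jaccard overlap for attributes, and a hypothetical weight `alpha` to combine them.

```python
import math

def cosine(u, v):
    """Cosine similarity between two feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def jaccard(a, b):
    """Overlap between two attribute sets."""
    a, b = set(a), set(b)
    return len(a & b) / len(a | b) if a | b else 0.0

def rank(query, database, alpha=0.5):
    """Filter by category (coarse level), then rank by a combined score."""
    candidates = [item for item in database
                  if item["category"] == query["category"]]
    scored = [(alpha * cosine(query["feature"], item["feature"])
               + (1.0 - alpha) * jaccard(query["attributes"], item["attributes"]),
               item["id"])
              for item in candidates]
    return [item_id for _, item_id in sorted(scored, reverse=True)]

# Hypothetical query and database entries.
query = {"category": "dress", "feature": [1.0, 0.0], "attributes": {"denim", "floral"}}
database = [
    {"id": "a", "category": "dress", "feature": [1.0, 0.0], "attributes": {"denim"}},
    {"id": "b", "category": "dress", "feature": [0.0, 1.0], "attributes": {"fur"}},
    {"id": "c", "category": "jeans", "feature": [1.0, 0.0], "attributes": {"denim", "floral"}},
]
ranking = rank(query, database)
```

Filtering at the coarse level first keeps the fine-grained comparison restricted to items of the same category, which is the main point of the coarse-to-fine design.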

Figure 2. An example of a relationship between the query image and semantic information from the coarse-grained level to the fine-grained level of the fashion ontology.

2. Fashion Object Ontology

In this section, we introduce the fashion object ontology. Within the fashion domain, we categorize semantic fashion concepts based on regions. Each region encompasses a detailed ontology comprising categories and attributes. To facilitate experiments using the DeepFashion dataset, we extend the fashion ontology within the "Clothes" branch (refer to Figure 9). It's essential to emphasize that the proposed ontology is not specific to any application and should be viewed as a flexible foundation.

The fashion object ontology consists of multiple levels of concepts, with relations between each level to articulate their associations. Two primary relations are employed:

  1. "Part of": this relation specifies that the concepts are components of the main concept.
  2. "Has a": this relation describes the main concept in detail.

For this study, we concentrate solely on the Clothes branch to ensure equitable comparisons with other methodologies. The Clothes taxonomy comprises 50 distinct categories. A clothing region taxonomy has been established (refer to Figure 3), organizing all clothing categories hierarchically. The first level of this hierarchy represents the most general clothing region, with three primary regions defined:

  1. Top (e.g., tee and tank)
  2. Bottom (e.g., skirt and jeans)
  3. Body (e.g., dress and robe)

Figure 3. Excerpt from the “Clothes” taxonomy defined in the fashion ontology.
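One simple way to hold such a taxonomy in code is a nested mapping from regions to categories. The sketch below is hypothetical: it contains only the example categories named in the text, not the full 50-category taxonomy, and the function and key names are illustrative.

```python
# Hypothetical excerpt of the "Clothes" branch, limited to the examples in the text.
CLOTHES_ONTOLOGY = {
    "Top": ["tee", "tank"],
    "Bottom": ["skirt", "jeans"],
    "Body": ["dress", "robe"],
}

def region_of(category, ontology=CLOTHES_ONTOLOGY):
    """Return the region a category is 'part of', or None if it is unknown."""
    for region, categories in ontology.items():
        if category in categories:
            return region
    return None

def describe(category, attributes, ontology=CLOTHES_ONTOLOGY):
    """Combine the coarse level (region, category) with 'has a' attributes."""
    return {
        "region": region_of(category, ontology),
        "category": category,
        "has_a": sorted(attributes),
    }

item = describe("jeans", {"denim"})
```

Keeping the ontology as data rather than code makes it easy to extend with new regions or categories, in line with the idea of the ontology as an extensible basis.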

3. Fine-Grained Object Ontology

The fine-grained object ontology is used to describe objects at the attribute level. Semantic information such as attributes can help a customer retrieve the items they are looking for (see Figure 4). It is important to note that the proposed ontology is not application dependent and should be considered as an extensible basis.

Figure 4. Fine-grained group at the attribute level.

Clothing attributes vary across different levels: some attributes, like color, are common across all clothing regions, while others are specific to certain regions or categories. Our ontology is structured into two main parts, each detailed in the following sections:

  1. Specific fashion concepts: particular characteristics of clothes such as fabric, part, and style.
  2. Visual concepts: popular visual characteristics like color, shape, and texture, which are not exclusive to fashion.

Rudd et al. demonstrated in a study that a multitask learning-based model outperforms a combination of single-task learning-based models in face attribute prediction. While this approach shows promising results for fashion attributes as well, there's a significant difference in the quantity of attributes between faces and fashion items. This disparity can pose challenges in scaling the system, such as in training and storage requirements. To address this, we propose applying local multitask learning to attribute learning, providing more flexibility. Further explanation is provided in the subsequent sections.

Next

In the next post, we will explore attribute learning and its correlation with multitask learning.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 3)

In the previous article, we delved into integrating the contrastive divergence loss function into federated learning, exploring its potential benefits for enhancing model performance and tackling non-IID data issues in autonomous driving contexts. In this follow-up piece, we delve into various federated learning configurations and present experimental findings that support the efficacy of our approach.

1. Implementation

Dataset: Our experimentation involves three datasets (see Table 1): Udacity+, Gazebo Indoor, and Carla Outdoor. Gazebo and Carla datasets exhibit non-IID characteristics, while Udacity+ represents a non-IID variant of the Udacity dataset.

| Dataset | Total samples | Average samples in each silo (Gaia) | Average samples in each silo (NWS) | Average samples in each silo (Exodus) |
|---|---|---|---|---|
| Udacity+ | 38,586 | 3,508 | 1,754 | 488 |
| Gazebo | 66,806 | 6,073 | 3,037 | 846 |
| Carla | 73,235 | 6,658 | 3,329 | 927 |

Table 1: Statistics of the datasets in our experiments.

Network Topology: Our experimentation encompasses three distinct federated topologies, namely the Internet Topology Zoo (Gaia), North American data centers (NWS), and the Zoo Exodus network (Exodus). The primary focus is on the Gaia topology, while supplementary insights from the NWS and Exodus topologies are provided in our ablation study.
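As a quick sanity check on Table 1, the per-silo averages follow from dividing each dataset's total by the number of silos in the topology. The silo counts used below (11 for Gaia, 22 for NWS, 79 for Exodus) are the ones this post reports in its topology comparison; the function name is illustrative.

```python
TOTAL_SAMPLES = {"Udacity+": 38586, "Gazebo": 66806, "Carla": 73235}
SILOS = {"Gaia": 11, "NWS": 22, "Exodus": 79}

def average_per_silo(total, silos):
    """Average number of samples per silo, rounded to the nearest integer."""
    return round(total / silos)

averages = {
    dataset: {topology: average_per_silo(total, count)
              for topology, count in SILOS.items()}
    for dataset, total in TOTAL_SAMPLES.items()
}
```

For example, 38,586 Udacity+ samples spread over the 79 Exodus silos give about 488 samples per silo, so each silo sees far less data as the topology grows.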

Training: Within each silo, model training is executed with a batch size of 32 and a learning rate set at 0.001, facilitated by the Adam optimizer. The local training regimen within each silo precedes the transmission and aggregation of models using the specified global aggregation equation. The training regimen spans 3,600 communication rounds and leverages a simulation environment akin to that described in Nguyen et al. (2022), powered by an NVIDIA 1080 GPU.

Baselines: Our comparative analysis involves several contemporary methods across diverse learning scenarios, including Random and Constant baselines as outlined by Loquercio et al. (2018). Within the Centralized Local Learning (CLL) scenario, we utilize Inception-V3, MobileNet-V2, VGG-16, and Dronet as baseline models. In the context of Server-based Federated Learning (SFL), our comparison extends to FedAvg, FedProx, and STAR. For the Decentralized Federated Learning (DFL) setting, our evaluation includes MATCHA, MBST, and FADNet. Model effectiveness is assessed using Root Mean Square Error (RMSE) and Mean Absolute Error (MAE), while wall-clock time (ms) serves as a metric for training duration.
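For reference, the two error metrics can be computed as in the following minimal sketch. The steering-angle values are made up for illustration and do not come from any of the datasets.

```python
import math

def rmse(predictions, targets):
    """Root Mean Square Error between predicted and ground-truth values."""
    squared = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return math.sqrt(sum(squared) / len(squared))

def mae(predictions, targets):
    """Mean Absolute Error between predicted and ground-truth values."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(predictions)

# Hypothetical steering angles.
predicted = [0.10, -0.20, 0.05, 0.30]
actual = [0.00, -0.10, 0.05, 0.50]
error_rmse = rmse(predicted, actual)
error_mae = mae(predicted, actual)
```

RMSE penalizes large individual errors more heavily than MAE, which is why the two metrics are usually reported together.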

2. Qualitative Results

In practice, we've noticed that the initial phases of federated learning often yield subpar accumulated models. Unlike other approaches that tackle the non-IID issue by refining the accumulation step whenever silos transmit their models, we directly mitigate the impact of divergence factors during the local learning phase of each silo. Our method aims to minimize the discrepancy between the distribution of accumulated weights from neighboring silos in the backbone network (representing divergence factors) and the weights specific to silo $i$ in the sub-network (comprising locally learned knowledge). Once the distribution between silos achieves an acceptable level of synchronization, we reduce the influence of the sub-network and prioritize the steering angle prediction task. Inspired by the contrastive loss of the original Siamese Network, our proposed Contrastive Divergence Loss is formulated as follows:

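| Model | Main Focus | Learning Method | RMSE (Udacity+) | RMSE (Gazebo) | RMSE (Carla) | MAE (Udacity+) | MAE (Gazebo) | MAE (Carla) | # Training Parameters | Avg. Cycle Time (ms) |
|---|---|---|---|---|---|---|---|---|---|---|
| Random | _ | _ | 0.358 | 0.117 | 0.464 | 0.265 | 0.087 | 0.361 | _ | _ |
| Constant | Statistical | _ | 0.311 | 0.092 | 0.348 | 0.209 | 0.067 | 0.232 | _ | _ |
| Inception | Architecture Design | CLL | 0.209 | 0.085 | 0.297 | 0.197 | 0.062 | 0.207 | 21,787,617 | _ |
| MobileNet | Architecture Design | CLL | 0.193 | 0.083 | 0.286 | 0.176 | 0.057 | 0.200 | 2,225,153 | _ |
| VGG-16 | Architecture Design | CLL | 0.190 | 0.083 | 0.316 | 0.161 | 0.050 | 0.184 | 7,501,587 | _ |
| DroNet | Architecture Design | CLL | 0.183 | 0.082 | 0.333 | 0.150 | 0.053 | 0.218 | 314,657 | _ |
| FedAvg | Aggregation Optimization | SFL | 0.212 | 0.094 | 0.269 | 0.185 | 0.064 | 0.222 | 314,657 | 152.4 |
| FedProx | Aggregation Optimization | SFL | 0.152 | 0.077 | 0.226 | 0.118 | 0.063 | 0.151 | 314,657 | 111.5 |
| STAR | Aggregation Optimization | SFL | 0.179 | 0.062 | 0.208 | 0.149 | 0.053 | 0.155 | 314,657 | 299.9 |
| MATCHA | Topology Design | DFL | 0.182 | 0.069 | 0.208 | 0.148 | 0.058 | 0.215 | 314,657 | 171.3 |
| MBST | Topology Design | DFL | 0.183 | 0.072 | 0.214 | 0.149 | 0.058 | 0.206 | 314,657 | 82.1 |
| FADNet | Topology Design | DFL | 0.162 | 0.069 | 0.203 | 0.134 | 0.055 | 0.197 | 317,729 | 62.6 |
| CDL (ours) | Loss Optimization | CLL | 0.169 | 0.074 | 0.266 | 0.149 | 0.053 | 0.172 | 629,314 | _ |
| CDL (ours) | Loss Optimization | SFL | 0.150 | 0.060 | 0.208 | 0.104 | 0.052 | 0.150 | 629,314 | 102.2 |
| CDL (ours) | Loss Optimization | DFL | 0.141 | 0.062 | 0.183 | 0.083 | 0.052 | 0.147 | 629,314 | 72.7 |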

Table 2: Performance comparison between different methods. The Gaia topology is used.

Table 2 summarizes the performance comparison between our proposed method and recent state-of-the-art approaches. The results indicate that our CDL under the Siamese setup with two ResNet-8 models outperforms other methods by a significant margin. Notably, our approach achieves substantial reductions in both RMSE and MAE across all three datasets: Udacity+, Gazebo, and Carla. CDL does not change the parameter count of the backbone network itself, but the additional sub-network required by the Siamese setup makes the model larger during training (629,314 training parameters, twice the 314,657 listed for DroNet). Additionally, our CDL with ResNet-8 demonstrates superior performance compared to other baselines, particularly in the DFL scenario, and to a lesser extent in the SFL and CLL setups.

3. Contrastive Divergence Loss Analysis

CDL Performance Across Various Topologies. In practice, training federated algorithms becomes more complex as the topology involves more vehicle data silos. To assess the efficacy of our CDL, we conduct training experiments and compare the results with other baseline methods across different topologies. Table 3 presents the performance comparison of DroNet, FADNet, and our CDL with a ResNet-8 backbone when trained using DFL across three distributed network infrastructures with varying numbers of silos: Gaia (11 silos), NWS (22 silos), and Exodus (79 silos). The table shows that our CDL consistently achieves superior results across all topology configurations. In contrast, DroNet encounters divergence issues, and FADNet exhibits suboptimal performance, particularly in the Exodus topology with 79 silos.

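| Topology | Architecture | Udacity+ | Gazebo | Carla |
|---|---|---|---|---|
| Gaia (11 silos) | DroNet | 0.177 (↓0.036) | 0.073 (↓0.011) | 0.244 (↓0.061) |
| Gaia (11 silos) | FADNet | 0.162 (↓0.021) | 0.069 (↓0.007) | 0.203 (↓0.020) |
| Gaia (11 silos) | CDL (ours) | 0.141 | 0.062 | 0.183 |
| NWS (22 silos) | DroNet | 0.183 (↓0.045) | 0.075 (↓0.017) | 0.239 (↓0.057) |
| NWS (22 silos) | FADNet | 0.165 (↓0.027) | 0.070 (↓0.012) | 0.200 (↓0.018) |
| NWS (22 silos) | CDL (ours) | 0.138 | 0.058 | 0.182 |
| Exodus (79 silos) | DroNet | 0.448 (↓0.310) | 0.208 (↓0.147) | 0.556 (↓0.380) |
| Exodus (79 silos) | FADNet | 0.179 (↓0.041) | 0.081 (↓0.020) | 0.238 (↓0.062) |
| Exodus (79 silos) | CDL (ours) | 0.138 | 0.061 | 0.176 |

Table 3: Performance (RMSE) under different topologies. Values in parentheses give the RMSE reduction achieved by CDL relative to each baseline.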
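The values in parentheses in Table 3 match the difference between each baseline's result and the CDL result in the same topology and dataset. A minimal check on the Udacity+ column, with an illustrative function name:

```python
def rmse_gap(baseline_rmse, cdl_rmse):
    """Reduction of CDL relative to a baseline, rounded to three decimals."""
    return round(baseline_rmse - cdl_rmse, 3)

# Udacity+ column of Table 3 (DFL training).
gaia_dronet = rmse_gap(0.177, 0.141)
exodus_dronet = rmse_gap(0.448, 0.138)
exodus_fadnet = rmse_gap(0.179, 0.138)
```

The gap for DroNet grows from 0.036 on Gaia to 0.310 on Exodus, which reflects the divergence issue noted in the text for the largest topology.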

CDL with Various Architectures. Our CDL, functioning as a loss function, exhibits versatility across different network architectures when integrated into the Siamese setup, leading to performance enhancements. Figure 1 showcases the efficacy of CDL across diverse networks such as DroNet, FADNet, Inception, MobileNet, and VGG-16 within the Gaia topology in the DFL scenario. The outcomes demonstrate CDL's efficacy in mitigating the non-IID challenge across varied architectures, consistently elevating performance.

Figure 1. Performance of CDL under different networks in Siamese setup.

CDL with IID Data. Figure 2 illustrates CDL's efficacy across various data distributions. While CDL is primarily tailored for addressing the non-IID challenge, it also demonstrates marginal performance enhancements when applied to models trained on IID data distributions. Leveraging the Siamese setup, CDL inherits traits and behaviors akin to triplet loss. Given triplet loss's established effectiveness with IID data, it is reasonable that CDL can similarly improve model performance in scenarios where IID data is used.

Figure 2. Performance of different methods on IID dataset (Udacity) and non-IID dataset (Udacity+).

4. Ablation Study

Figure 3 shows the training RMSE of our two baselines, DroNet and FADNet, as well as our proposed CDL. The results show the convergence behavior of these methods across three datasets (Udacity+, Gazebo, Carla) under the Gaia and NWS topologies. Our proposed CDL reaches a better convergence point than the two baselines. While DroNet and FADNet struggle to converge or exhibit poor convergence trends, CDL is better at escaping local optima and shows less bias towards any specific silo.

Figure 3. The convergence ability of different methods under Gaia topology (top row) and NWS topology (bottom row).

5. Discussion

Although our method exhibits promising results, there are several areas for potential improvement that warrant consideration for future work:

  • Our CDL is specifically designed to address the non-IID problem, meaning its effectiveness heavily depends on the presence of non-IID characteristics within the autonomous driving data. Thus, in scenarios where data distributions across vehicles are relatively consistent or lack significant non-IID factors, the proposed contrastive divergence loss may only lead to limited performance enhancements.

  • While our proposal has been validated on autonomous driving datasets, real-world testing on actual vehicles has not yet been conducted. Conducting a study involving various driving scenarios, such as interactions with pedestrians, cyclists, and other vehicles, could provide further validation of our method's efficacy.

  • Although our approach is tailored for autonomous driving applications, its underlying principles may have applicability in other domains where non-IID data is prevalent. Exploring its effectiveness in areas like healthcare, IoT, and industrial settings could expand its potential impact.

Conclusion

We presented a new method to address the non-IID problem in federated autonomous driving using contrastive divergence loss. Our method directly reduces the effect of divergence factors in the learning process of each silo. The experiments on three benchmarking datasets demonstrate that our proposed method performs substantially better than current state-of-the-art approaches. In the future, we plan to test our strategy with more data silos and deploy the trained model using an autonomous vehicle on roads.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 1)

Generating group dance motion from the music is a challenging task with several industrial applications. Although several methods have been proposed to tackle this problem, most of them prioritize optimizing the fidelity in dancing movement, constrained by predetermined dancer counts in datasets. This limitation impedes adaptability to real-world applications. Our study addresses the scalability problem in group choreography while preserving naturalness and synchronization.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with $100$ dancers.

In particular, we propose a phase-based variational generative model for group dance generation based on learning a generative manifold. Our method achieves high-fidelity group dance motion and enables generation with an unlimited number of dancers while consuming only a minimal and constant amount of memory. Extensive experiments on two public datasets show that our proposed method outperforms recent state-of-the-art approaches by a large margin and scales to a number of dancers far beyond that of the training data.

Introduction

The proliferation of digital social media platforms has significantly increased the popularity of creating and sharing dance videos. This surge in interest has led to the daily production and consumption of millions of dance videos across various online platforms, attracting attention from the research community in fields such as computer vision and computer graphics. Recent advancements in these areas have focused on generating authentic dance movements in response to music, with broad applications spanning animation, virtual idols, virtual metaverses, and dance education. These technologies provide powerful tools for artists, animators, and educators, enhancing creativity and enriching the dance experience for both performers and audiences.

Figure 2. An example visualization of a virtual group dancer.
While considerable progress has been made in generating solo dance motions, creating synchronized and realistic group dance performances remains a complex and unresolved challenge. Existing methods typically face limitations in scalability, either generating dances for a fixed number of performers or suffering from high memory consumption due to architectural constraints. These approaches, often based on collaborative mechanisms such as cross-entity or global attention, struggle to scale up for larger groups of dancers, limiting their practical applicability. Moreover, the reliance on predefined datasets with a fixed number of dancers further restricts these models from being adapted to real-world scenarios requiring larger group choreographies.

To address these challenges, we propose a novel approach to scalable group dance generation using a phase-based variational generative model, termed Phase-conditioned Dance VAE (PDVAE). Unlike traditional variational autoencoders that operate in high-dimensional motion space and often struggle to capture temporal dynamics, PDVAE leverages phase parameters in the frequency domain to represent the latent motion space. This enables the generation of realistic and synchronized group dance performances without increasing computational and memory costs, even as the number of dancers grows. PDVAE provides a flexible and efficient solution for generating crowd-scale dance animations, with potential applications in diverse fields such as entertainment, virtual reality, education, and media production.

In this work, we aim to advance the state-of-the-art in scalable group dance generation by overcoming the limitations of existing methods and demonstrating the feasibility of generating large-scale, natural, and synchronized dance performances efficiently.

To summarize, our key contributions are as follows:
- We introduce PDVAE, a phase-based variational generative model for scalable group dance generation. The method focuses on generating large-scale group dance under limited resources.
- To effectively learn the manifold that is group-consistent (i.e., dancers within a group lie upon the same manifold), we propose a group consistency loss that enforces the networks to encode the latent phase manifold to be identical for the same group given the input music.
- Extensive experiments along with thorough user study evaluations demonstrate the state-of-the-art performance of our model while achieving effective scalability.

Related Works

Music-driven Choreography

Crafting natural human choreography derived from music poses a complex challenge, encompassing the need for synchronization, coherence, and expressiveness between movement and musical inputs. Earlier approaches often relied on music-motion similarity constraints to ensure alignment between the generated dance and the accompanying music. Many of these methods employ heuristic algorithms to stitch together pre-existing dance segments from limited music-dance databases, successfully producing extended and realistic dance sequences. However, such techniques are constrained when attempting to generate novel dance fragments, as they depend heavily on pre-defined segments rather than innovative motion generation.

Recent advancements have focused on utilizing deep learning architectures to map music into dance motions. Various techniques, including Convolutional Networks (CNN), Recurrent Networks (RNN), Graph Neural Networks (GNN), Generative Adversarial Networks (GAN), and Transformer models, have been explored for dance generation. These models typically rely on inputs such as the current music and a brief history of previous dance motions to predict the next human poses in a sequence. For instance, innovations in multi-modal feature fusion have enabled the simultaneous integration of music and text, producing dance sequences guided by both musical and textual cues.

Despite the progress made by these methods, generating synchronized and coordinated dance movements for multiple dancers remains a significant challenge. Achieving harmony between multiple dancers requires not only temporal alignment but also the consideration of spatial relationships and interactions between performers. This adds complexity to the generation process, making it difficult for current techniques to manage multiple dancers cohesively. Additionally, these methods are often constrained by the limited number of dancers present in their training datasets.

Figure 3. Multi-person Tracking with GNN.

Several recent works have sought to address these limitations. For instance, Perez et al. propose a multimodal transformer combined with a normalizing-flow-based decoder to predict a distribution of possible future poses, offering improved flexibility in motion generation. Feng et al. introduce a motion repeat constraint for long-term generation, allowing their model to generate future frames while taking into account historical dance motions. Le et al. further explore group dance by examining consistency and diversity among dancers, ensuring that generated movements maintain coherence across multiple performers. However, these methods are still limited by the number of dancers depicted in training datasets, restricting their scalability for larger group performances.

Figure 4. Local Mesh Fitting process for conducting a dance dataset.

Overall, while significant strides have been made in music-driven dance generation, further research is needed to overcome the challenges of scalability and synchronization in group dance synthesis.

Motion Manifold Learning

Motion manifold learning has garnered significant attention in computer vision and artificial intelligence, primarily aiming to understand the fundamental structures underlying human movement and dynamics. This approach offers the capability to generate human motion patterns, providing insights into intrinsic motion dynamics, managing nonlinear relationships within motion data, and learning contextual and hierarchical representations. Consequently, numerous methodologies have emerged, each contributing distinct perspectives to advance the comprehension and synthesis of human motion.

Holden et al. pioneered motion manifold learning by generating character movements from high-level parameters mapped to a motion manifold, eliminating the need for manual preprocessing while enabling natural, smooth post-generation editing of motion sequences. MotionCLIP introduced a 3D human motion autoencoder aligned with CLIP's semantic space, facilitating semantic text-based motion generation, disentangled editing, and abstract language specification. This approach capitalizes on CLIP's rich semantic knowledge, integrating it within the motion manifold for enhanced control and interpretation of human movement. Sun et al. further advanced this field by employing VQ-VAE to learn a low-dimensional motion manifold, effectively refining motion sequences to improve coherence and continuity.

Figure 5. Manifold conduction.
In the context of group dance generation, motion manifold learning presents a promising solution to scalability challenges, particularly the restricted number of dancers present in most datasets. By learning a distribution over dance motions within the manifold, this approach enables the generation of synchronized, cohesive group dance sequences, potentially overcoming dataset limitations. This direction of research highlights the potential of manifold-based methods to enhance scalability, allowing for the synthesis of large-scale, realistic group dance performances that maintain temporal and spatial harmony.

Next

In the next part, we will dive into the main proposal, which leverages manifold learning to handle the scalability of group dance generation.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 2)

In the previous post, we delved into the realm of federated learning in the context of autonomous driving, exploring its potential and the challenges posed by non-IID data distribution. We also discussed how contrastive divergence offers a promising avenue to tackle the non-IID problem in federated learning setups. Building upon that foundation, in this post, we will delve deeper into the integration of contrastive divergence loss into federated learning frameworks. We'll explore the mechanics of incorporating this loss function into the federated learning process and examine its potential implications for improving model performance and addressing non-IID data challenges in autonomous driving scenarios. We unravel the intricacies of leveraging contrastive divergence within federated learning paradigms to advance the capabilities of autonomous driving systems.

1. Overview

Motivation: The effectiveness of federated learning algorithms in autonomous driving hinges on two critical factors: firstly, the ability of each local silo to glean meaningful insights from its own data, and secondly, the synchronization among neighboring silos to mitigate the impact of the non-IID problem. Recent efforts have primarily focused on addressing these challenges through various means, including optimizing accumulation processes and optimizers, proposing novel network topologies, or leveraging robust deep networks capable of handling the distributed nature of the data. However, as highlighted by Duan et al., the indiscriminate adoption of high-performance deep architectures and their associated optimizations in centralized local learning scenarios can lead to increased weight variance among local silos during the accumulation process in federated setups. This variance detrimentally impacts model convergence and may even induce divergence, underscoring the need for nuanced approaches to ensure the efficacy of federated learning in autonomous driving contexts.

Figure 1. The Siamese setup when our CDL is applied for training a federated autonomous driving model. ResNet-8 is used in the backbone and sub-network in the Siamese setup. During inference, the sub-network will be removed. Dotted lines represent the backward process. Our CDL has two components: the positive contrastive divergence loss $\mathcal{L}_{\rm cd^+}$ and the negative regularizer term $\mathcal{L}_{\rm cd^-}$. The local regression loss $\mathcal{L}_{\rm lr}$ for automatic steering prediction is calculated only from the backbone network.

Siamese Network Approach: We propose to tackle the non-IID problem directly within each local silo by handling two challenges separately: learning optimal features and achieving synchronization. Our strategy implements two distinct networks within each silo: one network extracts meaningful features from local image data, while the other minimizes the distribution gap between the current model weights and those of neighboring silos. To this end, we employ a Siamese Network architecture with two branches. The first branch, the backbone network, learns local image features for autonomous steering using a local regression loss $\mathcal{L}_{\rm lr}$, while also incorporating a positive contrastive divergence loss $\mathcal{L}_{\rm cd^+}$ to assimilate knowledge from neighboring silos. The second branch, referred to as the sub-network, regulates divergence factors arising from the backbone's knowledge through a contrastive regularizer term $\mathcal{L}_{\rm cd^-}$. See Figure 1 for more detail.

In practice, the sub-network initially adopts the same weights as the backbone during the initial communication round. However, starting from the subsequent communication rounds, once the backbone undergoes accumulation using Equation below, each silo's local model is trained using the contrastive divergence loss. The sub-network produces auxiliary features of identical dimensions to the output features of the backbone. Throughout training, we anticipate minimal discrepancies in weights between the backbone and the sub-network when employing the contrastive divergence loss. Synchronization of weights across all silos occurs when gradients from the backbone and sub-network learning processes exhibit minimal disparity.

$$\theta_i\left(k + 1\right) = \theta_i\left(k\right)-\alpha_{k}\frac{1}{m}\sum^m_{h=1}\nabla \mathcal{L}_{\rm lr}\left(\theta_i\left(k\right),\xi_i^h\left(k\right)\right)$$

where $\mathcal{L}_{\rm lr}$ is the local regression loss for autonomous steering.
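As a rough illustration of the update rule above, the following sketch performs one local step on a mini-batch of $m$ samples. The toy linear model and the names `local_sgd_step` and `squared_error_grad` are hypothetical stand-ins chosen for this example; in the post, the local model is a deep network and the loss is $\mathcal{L}_{\rm lr}$.

```python
def local_sgd_step(theta, batch, grad_fn, lr):
    """One local update: theta <- theta - lr * (mean gradient over the mini-batch)."""
    m = len(batch)
    grad_sum = [0.0] * len(theta)
    # Sum the per-sample gradients of the local loss.
    for sample in batch:
        g = grad_fn(theta, sample)
        grad_sum = [a + b for a, b in zip(grad_sum, g)]
    # Average over the m samples and take a step of size lr (alpha_k in the text).
    return [t - lr * gs / m for t, gs in zip(theta, grad_sum)]


def squared_error_grad(theta, sample):
    """Gradient of 0.5 * (w*x + b - y)^2 w.r.t. (w, b).

    Hypothetical toy loss used only to make the sketch runnable.
    """
    x, y = sample
    w, b = theta
    err = w * x + b - y
    return [err * x, err]


theta = [0.0, 0.0]
batch = [(1.0, 2.0), (2.0, 4.0)]
theta_next = local_sgd_step(theta, batch, squared_error_grad, lr=0.1)
```

Any differentiable loss can be plugged in through `grad_fn`; the step itself only depends on the averaged gradient.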

2. Contrastive Divergence Loss

In practice, we've noticed that the initial phases of federated learning often yield subpar accumulated models. Unlike other approaches that tackle the non-IID issue by refining the accumulation step whenever silos transmit their models, we directly mitigate the impact of divergence factors during the local learning phase of each silo. Our method aims to minimize the discrepancy between the distribution of accumulated weights from neighboring silos in the backbone network (representing divergence factors) and the weights specific to silo $i$ in the sub-network (comprising locally learned knowledge). Once the distribution between silos achieves an acceptable level of synchronization, we reduce the influence of the sub-network and prioritize the steering angle prediction task. Inspired by the contrastive loss of the original Siamese Network, our proposed Contrastive Divergence Loss is formulated as follows:

$$\mathcal{L}_{\rm cd} = \beta \mathcal{L}_{\rm cd^+} + (1-\beta) \mathcal{L}_{\rm cd^-} = \beta \mathcal{H}(\theta^b_i, \theta^s_i) + (1-\beta) \mathcal{H}(\theta^s_i,\theta^b_i)$$

where $\mathcal{L}_{\rm cd^+}$ is the positive contrastive divergence term and $\mathcal{L}_{\rm cd^-}$ is the negative regularizer term; $\mathcal{H}$ is the Kullback-Leibler Divergence loss function:

$$\mathcal{H}(\hat{y},y) = \sum \mathbf{f}(\hat{y}) \log\left(\frac{\mathbf{f}(\hat{y})}{\mathbf{f}(y)}\right)$$

where $\hat{y}$ is the predicted representation and $y$ is the dynamic soft label.
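To make the two-sided structure of the loss concrete, here is a minimal sketch that evaluates $\mathcal{H}$ and the weighted combination on small feature vectors. It assumes that $\mathbf{f}$ is a softmax and that the loss is applied to output features of the two branches; neither detail is stated in this post, and the function names are hypothetical.

```python
import math


def softmax(v):
    """Numerically stable softmax; stands in for f(.) (an assumption)."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(y_hat, y):
    """H(y_hat, y) = sum f(y_hat) * log(f(y_hat) / f(y))."""
    p = softmax(y_hat)
    q = softmax(y)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def contrastive_divergence_loss(feat_b, feat_s, beta):
    """L_cd = beta * H(b, s) + (1 - beta) * H(s, b)."""
    positive = kl_divergence(feat_b, feat_s)   # L_cd+ : backbone vs. sub-network
    negative = kl_divergence(feat_s, feat_b)   # L_cd- : roles swapped
    return beta * positive + (1.0 - beta) * negative


backbone_features = [0.2, 1.0, -0.5]
subnet_features = [0.1, 0.8, -0.3]
loss_cd = contrastive_divergence_loss(backbone_features, subnet_features, beta=0.5)
```

Because the Kullback-Leibler divergence is not symmetric, the two terms generally differ, and $\beta$ controls how much weight each direction receives.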

Considering $\mathcal{L}_{\rm cd^+}$ in the equation above as a Bayesian statistical inference task, our goal is to estimate the model parameters $\theta^{b*}$ by minimizing the Kullback-Leibler divergence $\mathcal{H}(\theta^b_i, \theta^s_i)$ between the measured regression probability distribution of the observed local silo $P_0 (x|\theta^s_i)$ and the accumulated model $P (x|\theta^b_i)$. Hence, we can assume that the model distribution has the form $P (x|\theta^b_i) = e^{-E(x,\theta^b_i)}/Z(\theta^b_i)$, where $Z(\theta^b_i)$ is the normalization term. However, evaluating the normalization term $Z(\theta^b_i)$ is not trivial, which leads to risks of getting stuck in a local minimum. Inspired by Hinton, we use samples obtained through a Markov Chain Monte Carlo (MCMC) procedure with a specific initialization strategy to deal with this problem. Additionally, as inferred from the equation above, $\mathcal{L}_{\rm cd^+}$ can be expressed under the SGD algorithm in a local silo by setting:

$$\mathcal{L}_{\rm cd^+} = -\sum_{x}P_0 (x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i} + \sum_{x}Q_{\theta^b_i} (x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i}$$

where $Q_{\theta^b_i} (x|\theta^s_i)$ is the measured probability distribution on the samples obtained by initializing the chain at $P_0 (x|\theta^s_i)$ and running the Markov chain forward for a defined step.
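For readers who want the intermediate step, the following is a sketch of the standard contrastive divergence argument in the notation above. It assumes that $P_0$ does not depend on $\theta^b_i$, that $x$ ranges over a discrete set, and that the divergence is taken from $P_0$ to $P$ as in the usual formulation; these are assumptions made for illustration, not statements from the paper.

Starting from the energy-based form of the model distribution:

$$\log P(x|\theta^b_i) = -E(x;\theta^b_i) - \log Z(\theta^b_i), \qquad \frac{\partial \log Z(\theta^b_i)}{\partial \theta^b_i} = -\sum_{x} P(x|\theta^b_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i}$$

the negative gradient of the Kullback-Leibler divergence becomes:

$$-\frac{\partial}{\partial \theta^b_i}\sum_{x} P_0(x|\theta^s_i)\log\frac{P_0(x|\theta^s_i)}{P(x|\theta^b_i)} = -\sum_{x} P_0(x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i} + \sum_{x} P(x|\theta^b_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i}$$

The second sum requires samples from the model distribution itself, which is where the intractable normalization term enters. Contrastive divergence replaces $P(x|\theta^b_i)$ in that sum with $Q_{\theta^b_i}(x|\theta^s_i)$, the distribution reached after running a short Markov chain started at $P_0$, which yields the two-term expression for $\mathcal{L}_{\rm cd^+}$.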

Considering the $\mathcal{L}_{\rm cd^-}$ regularizer as a Bayesian statistical inference task in the same way, we can calculate $\mathcal{L}_{\rm cd^-}$ as in the equation above, except that the roles of $\theta^s$ and $\theta^b$ are swapped:

$$\mathcal{L}_{\rm cd^-}=-\sum_{x}P_0 (x|\theta^b_i)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i} + \sum_{x}Q_{\theta^s_i} (x|\theta^b_i)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i}$$

We note that the key difference is that while the weight $\theta^b_i$ of the backbone is updated by the accumulation process, the weight $\theta^s_i$ of the sub-network is not. This leads to different convergence behavior of contrastive divergence in $\mathcal{L}_{\rm cd^+}$ and $\mathcal{L}_{\rm cd^-}$. The negative regularizer term $\mathcal{L}_{\rm cd^-}$ will converge to state $\theta^{s*}_i$ provided $\frac{\partial E}{\partial \theta^s_i}$ is bounded:

$$g(x,\theta^s_i) = \frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i} -\sum_{x}P_0 (x|(\theta^b_i,\theta^s_i))\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i}$$

$$(\theta^s_i - \theta^{s*}_i)\cdot\left\{ \sum_{x}P_0(x)g(x,\theta^s_i) - \sum_{x',x}P_0(x')\mathbf{K}^m_{\theta^s_i}(x',x)g(x,\theta^{s*}_i)\right\}\geq \mathbf{k}_1|\theta^s_i-\theta^{s*}_i|^2$$

for any $\mathbf{k}_1$ constraint. Note that $\mathbf{K}^m_{\theta^s}$ is the transition kernel.

Note that the negative regularizer term $\mathcal{L}_{\rm cd^-}$ is only used in training models on local silos. Thus, it does not contribute to the accumulation process of federated training.

3. Total Training Loss

Local Regression Loss. We use the mean absolute error (MAE) to compute the loss for predicting the steering angle in each local silo. Note that we only use features from the backbone for predicting steering angles.

$$\mathcal{L}_{\rm lr} = \text{MAE}(\theta^b_i, \hat{\xi}_i )$$

where $\hat{\xi}_i$ is the ground-truth steering angle of the data sample $\xi_i$ collected from silo $i$.
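As a small worked example of the regression loss, the sketch below computes it on three made-up steering angles, assuming MAE denotes the standard mean absolute error between the backbone's predictions and the ground truth. The values and the function name are illustrative only.

```python
def mean_absolute_error(predictions, targets):
    """Average of |prediction - target| over a batch of steering angles."""
    if not predictions or len(predictions) != len(targets):
        raise ValueError("predictions and targets must be non-empty and equal in length")
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(predictions)


# Hypothetical backbone outputs and ground-truth angles (illustrative values).
predicted_angles = [0.10, -0.25, 0.40]
ground_truth_angles = [0.00, -0.20, 0.50]

loss_lr = mean_absolute_error(predicted_angles, ground_truth_angles)
# (0.10 + 0.05 + 0.10) / 3, i.e. roughly 0.083
```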

Local Silo Loss. The local silo loss computed in each communication round at each silo before applying the accumulation process is described as:

$$\mathcal{L}_{\rm final} = \mathcal{L}_{\rm lr} + \mathcal{L}_{\rm cd}$$

In practice, we observe that the contrastive divergence loss $\mathcal{L}_{\rm cd}$, which handles the non-IID problem, and the local regression loss $\mathcal{L}_{\rm lr}$, which predicts the steering angle, are equally important and indispensable.

Combining all losses together, at each iteration $k$, the update in the backbone network is defined as:

$$\theta^b_i\left(k + 1\right) =\begin{cases} \sum_{j \in \mathcal{N}_i^{+} \cup \{i\}}\mathbf{A}_{i,j}\theta^b_{j}\left(k\right), & \text{if } k \equiv 0 \pmod{u + 1},\\ \theta^b_i\left(k\right)-\alpha_{k}\frac{1}{m}\sum^m_{h=1}\nabla \mathcal{L}_{\rm b}\left(\theta^b_i\left(k\right),\xi_i^h\left(k\right)\right), & \text{otherwise.} \end{cases}$$

where $\mathcal{L}_{\rm b} = \mathcal{L}_{\rm lr} + \mathcal{L}_{\rm cd^+}$ and $u$ is the number of local updates.
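The two-case backbone update can be sketched as follows: every $(u+1)$-th iteration is an accumulation step over the in-neighbors and the silo itself, and all other iterations are local gradient steps on $\mathcal{L}_{\rm b}$. The data layout (dictionaries keyed by silo index) and the name `backbone_step` are hypothetical, and the mini-batch gradient of $\mathcal{L}_{\rm b}$ is assumed to be computed elsewhere and passed in.

```python
def backbone_step(k, u, thetas, i, in_neighbors, A, grad_b, lr):
    """Return theta_i^b(k+1) given the backbone weights of all silos at iteration k.

    thetas:       dict silo index -> list of backbone weights
    in_neighbors: indices in N_i^+ (silo i itself is added below)
    A:            dict of dicts holding the mixing coefficients A[i][j]
    grad_b:       mean mini-batch gradient of L_b = L_lr + L_cd+ at thetas[i]
    """
    if k % (u + 1) == 0:
        # Accumulation step: weighted sum over in-neighbors and the silo itself.
        group = list(in_neighbors) + [i]
        dim = len(thetas[i])
        return [sum(A[i][j] * thetas[j][d] for j in group) for d in range(dim)]
    # Local step: plain gradient descent on the backbone loss.
    return [t - lr * g for t, g in zip(thetas[i], grad_b)]


thetas = {0: [1.0, 2.0], 1: [3.0, 4.0]}
A = {0: {0: 0.5, 1: 0.5}}

accumulated = backbone_step(k=3, u=2, thetas=thetas, i=0, in_neighbors=[1],
                            A=A, grad_b=[1.0, 1.0], lr=0.1)
local = backbone_step(k=4, u=2, thetas=thetas, i=0, in_neighbors=[1],
                      A=A, grad_b=[1.0, 1.0], lr=0.1)
```

With `u=2`, iterations 0, 3, 6, ... accumulate, and the two iterations in between are local updates.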

In parallel, the update in the sub-network at each iteration $k$ is described as:

$$\theta^s_i\left(k + 1\right) =\theta^s_i\left(k\right)-\alpha_{k}\frac{1}{m}\sum^m_{h=1}\nabla \mathcal{L}_{\rm cd^-}\left(\theta^s_i\left(k\right),\xi_i^h\left(k\right)\right)$$

Next

In the next post, we will evaluate the effectiveness of Contrastive Divergence Loss in dealing with the non-IID problem in Federated Autonomous Driving.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 1)

Federated learning has been widely applied in autonomous driving since it enables training a learning model among vehicles without sharing users' data. However, data from autonomous vehicles usually suffer from the non-independent-and-identically-distributed (non-IID) problem, which may cause negative effects on the convergence of the learning process. In this paper, we propose a new contrastive divergence loss to address the non-IID problem in autonomous driving by reducing the impact of divergence factors from transmitted models during the local learning process of each silo. We also analyze the effects of contrastive divergence in various autonomous driving scenarios, under multiple network infrastructures, and with different centralized/distributed learning schemes. Our intensive experiments on three datasets demonstrate that our proposed contrastive divergence loss significantly improves the performance over current state-of-the-art approaches.

1. Introduction

Autonomous driving represents a burgeoning field wherein vehicles navigate without human intervention, employing a blend of vision, learning, and control algorithms to perceive and react to environmental changes. While traditional approaches heavily rely on supervised learning and data collection for model training, recent efforts, such as those by Bergamini et al., Chen et al., and Wang et al., have proposed various solutions to different challenges in autonomous driving. However, data collection poses privacy concerns as user data is shared with third parties, prompting a shift towards federated learning (FL). FL allows multiple parties to collaboratively train models without sharing raw data, thereby preserving user privacy while enabling autonomous vehicles to collectively learn predictive models from diverse datasets.

Typically, there are two primary federated learning scenarios: Server-based Federated Learning (SFL) and Decentralized Federated Learning (DFL). In SFL, a central node coordinates the training process and aggregates contributions from all clients, enhancing data privacy by transmitting only model weights. However, the reliance on central nodes in SFL can potentially bottleneck the system due to data transmission constraints. In contrast, DFL operates without a central server, employing a fully distributed network architecture. In autonomous driving, several works have explored both DFL and SFL to address different problems such as collision avoidance, trajectory prediction, and steering prediction.

Figure 1. Sample viewpoints over three vehicles and their steering angle distributions in Carla dataset. The visualization shows that the three vehicles have differences in visual input as well as steering angle distribution.

In practical application, both Server-based Federated Learning (SFL) and Decentralized Federated Learning (DFL) approaches exhibit their own advantages and limitations, yet both are susceptible to the non-IID problem inherent in federated learning. The non-IID problem, as described by Sattler et al. and Wang et al., arises when data partitioning across silos exhibits significant distribution shifts. Particularly in autonomous driving, this issue poses significant challenges during the accumulation process across vehicle silos. Each vehicle's unique driving patterns, weather conditions, and road types contribute to differences in data distribution, exacerbating the non-IID problem. For instance, data collected from vehicles on highways may differ substantially from data collected in urban settings. This discrepancy can lead to difficulties in constructing robust and accurate learning models, where reliance on data from a specific context may result in poor performance in others.

Previous studies have tackled the non-IID problem through various approaches such as optimizing the accumulation step, incorporating normalization into the global model, fine-tuning global model weights via distillation, or aligning the distribution between the global model and local ones. These methods typically refine global model weights to adapt to divergence factors arising from diverse distributions across local datasets. However, this optimization process can be challenging due to factors like determining the optimal accumulation step size for each local silo, managing network topology disconnections and bottlenecks, and ensuring consistency between local and global objectives.

In the paper, we introduce a novel Contrastive Divergence Loss (CDL) to tackle the non-IID problem in autonomous driving. Unlike other methods that wait for local model learning before addressing the non-IID issue during global aggregation, our CDL loss directly mitigates the impact of divergence factors throughout each local silo's learning process. This approach simplifies the adaptation of distribution between neighboring silos compared to handling the non-IID problem in the global model. Consequently, each local model becomes more resilient to changes in data distribution, allowing the use of typical accumulation methods like FedAvg for global weights without concerns about non-IID effects on convergence. The intensive experiments on three autonomous driving datasets verify our observation and show significant improvements over state-of-the-art methods.

2. Related Works

Autonomous Driving: Autonomous driving has emerged as a prominent research area in recent years, with a focus on leveraging deep learning for various tasks such as object detection, trajectory prediction, and autonomous control. For instance, Xin et al. proposed a recursive backstepping steering controller, while Xiong et al. analyzed nonlinear dynamics behavior using proportional control laws. Yi et al. presented an algorithm for self-reconfigurable robots during waypoint navigation, and Yin et al. combined model predictive control with covariance steering theory for robust autonomous driving systems.

Federated Learning for Autonomous Driving: Federated learning offers a privacy-aware solution for collaborative machine learning without sharing local data. It has gained traction in autonomous driving and intelligent transport systems, addressing communication, computation, and storage efficiency. Various works have explored federated learning for autonomous driving tasks such as steering angle prediction and turning signal prediction. However, challenges remain, including non-IID data distribution among participants, which recent works primarily address through accumulation processes rather than local silo optimization.

Contrastive Divergence: Contrastive models have gained attention in federated learning for handling heterogeneity in local data distribution. While contrastive learning has been applied to various datasets, its potential for mitigating non-IID data effects in federated autonomous driving remains underexplored. Most existing works focus on optimizing frameworks in server-based or P2P federated learning, or adjusting the accumulation process, with limited consideration of contrastive loss function behavior in federated autonomous driving scenarios.

Figure 2. A federated autonomous driving system.

3. Preliminary

We summarize the notations of our paper in Table 1.

Table 1. Notations.

We consider each autonomous vehicle as a data silo. Our goal is to collaboratively train a global driving policy $\theta$ from $N$ silos by aggregating all local learnable weights $\theta_i$ of each silo. Each silo computes the current gradient of its local loss function and then updates the model parameter using an optimizer. Mathematically, in the local update stage, at each silo $i$, in each iteration $k$, the model weights of the local silo can be computed as:

$$\theta_i\left(k + 1\right) = \theta_i\left(k\right)-\alpha_{k}\frac{1}{m}\sum^m_{h=1}\nabla \mathcal{L}_{\rm lr}\left(\theta_i\left(k\right),\xi_i^h\left(k\right)\right)$$

where $\mathcal{L}_{\rm lr}$ is the local regression loss for autonomous steering. To update the global model, each silo interacts with the associated ones through a predefined topology:

$$\theta\left(k + 1\right) = \sum^{N-1}_{i = 0 }\vartheta_i\theta_i\left(k\right)$$
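The aggregation above is a weighted sum of the silo weights. The sketch below implements it and, purely as an assumption for illustration, derives the coefficients $\vartheta_i$ from local dataset sizes in the style of FedAvg; the post does not specify how the coefficients are chosen, and the function names are hypothetical.

```python
def aggregate(silo_weights, coefficients):
    """theta(k+1) = sum_i coefficients[i] * silo_weights[i], element-wise."""
    if len(silo_weights) != len(coefficients):
        raise ValueError("need exactly one coefficient per silo")
    dim = len(silo_weights[0])
    return [
        sum(c * w[d] for c, w in zip(coefficients, silo_weights))
        for d in range(dim)
    ]


def size_proportional_coefficients(sample_counts):
    """FedAvg-style choice (an assumption): weight each silo by its share of the data."""
    total = sum(sample_counts)
    return [n / total for n in sample_counts]


silo_weights = [[1.0, 2.0], [3.0, 4.0]]
coefficients = size_proportional_coefficients([100, 300])  # [0.25, 0.75]
global_weights = aggregate(silo_weights, coefficients)      # [2.5, 3.5]
```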

In practice, the local model in each silo is a deep network that takes RGB images as inputs and predicts the associated steering angles.

Next

In the next post, we will introduce our proposal with contrastive divergence loss.

Fine-Grained Visual Classification using Self-Assessment Classifier (Part 2)

In the previous part, we discussed our proposal for fine-grained image classification. In this part, we verify its effectiveness and efficiency.

1. Experimental Setup

Dataset. We evaluate our method on three popular fine-grained datasets: CUB-200-2011, Stanford Dogs and FGVC Aircraft (See Table 1).

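| Dataset | Target | # Categories | # Train | # Test |
| --- | --- | --- | --- | --- |
| CUB-200-2011 | Bird | 200 | 5,994 | 5,794 |
| Stanford Dogs | Dog | 120 | 12,000 | 8,580 |
| FGVC-Aircraft | Aircraft | 100 | 6,667 | 3,333 |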

Table 1: Fine-grained classification datasets in our experiments.

Implementation. All experiments are conducted on an NVIDIA Titan V GPU with 12GB RAM. The model is trained using Stochastic Gradient Descent with a momentum of 0.9. The maximum number of epochs is set at 80; the weight decay equals 0.00001, and the mini-batch size is 12. Besides, the initial learning rate is set to 0.001, with exponential decay of 0.9 after every two epochs. Based on validation results, the number of top-k ambiguity classes is set to 10, while the parameters $d_{\phi}$ and $\alpha$ are set to 0.1 and 0.5, respectively.
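The learning-rate schedule described above can be sketched as a simple function of the epoch index. This reading assumes zero-indexed epochs and a stepwise decay (the rate is multiplied by 0.9 once every two epochs); the exact indexing convention is not stated, and the function name is hypothetical.

```python
def learning_rate(epoch, initial_lr=0.001, decay=0.9, step=2):
    """Learning rate at a zero-indexed epoch under stepwise exponential decay."""
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    # The rate is multiplied by `decay` once for every `step` completed epochs.
    return initial_lr * decay ** (epoch // step)


# Schedule over the full 80-epoch training run.
schedule = [learning_rate(epoch) for epoch in range(80)]
```

Under this reading, epochs 0 and 1 use 0.001, epochs 2 and 3 use 0.0009, and so on.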

Baseline. To validate the effectiveness and generalization of our method, we integrate it into 7 different deep networks, including two popular Deep CNN backbones, Inception-V3 and ResNet-50, and five fine-grained classification methods: WS, DT, WS_DAN, MMAL, and the recent transformer work ViT. It is worth noting that we only add our Self Assessment Classifier to these works; all other setups and training hyper-parameters are kept unchanged from the original code.

2. Experimental Results

Table 2 summarises the contribution of our Self Assessment Classifier (SAC) to the fine-grained classification results of different methods on three datasets: CUB-200-2011, Stanford Dogs, and FGVC Aircraft. The table shows that integrating SAC into different classifiers consistently improves the fine-grained classification results. In particular, we observe an average improvement of +1.3, +1.2, and +1.2 on the CUB-200-2011, Stanford Dogs, and FGVC Aircraft datasets, respectively.

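| Methods | CUB-200-2011 | Stanford Dogs | FGVC Aircraft |
| --- | --- | --- | --- |
| MAMC | 86.5 | 85.2 | _ |
| PC | 86.9 | 83.8 | 89.2 |
| MC | 87.3 | _ | 92.9 |
| DCL | 87.8 | _ | 93.0 |
| ACNet | 88.1 | _ | 92.4 |
| DF-GMM | 88.8 | _ | 93.8 |
| API-Net | 90.0 | 90.3 | 93.9 |
| GHORD | 89.6 | _ | 94.3 |
| CAL | 90.6 | _ | 94.2 |
| Parts Models | 90.4 | 93.9 | _ |
| ViT + DCAL | 91.4 | _ | 91.5 |
| P2P-Net | 90.2 | _ | 94.2 |
| Inception-V3 | 83.7 | 85.1 | 87.4 |
| Inception-V3 + SAC | 85.3 (+1.6) | 86.8 (+1.7) | 89.2 (+1.8) |
| ResNet-50 | 86.4 | 86.1 | 90.3 |
| ResNet-50 + SAC | 88.3 (+1.9) | 87.4 (+1.3) | 92.1 (+1.8) |
| WS | 88.8 | 91.4 | 92.3 |
| WS + SAC | 89.9 (+1.1) | 92.5 (+1.1) | 93.2 (+0.9) |
| DT | 89.2 | 88.0 | 90.7 |
| DT + SAC | 90.1 (+0.9) | 88.8 (+0.8) | 91.9 (+1.2) |
| MMAL | 89.6 | 90.6 | 94.7 |
| MMAL + SAC | 90.8 (+1.2) | 91.6 (+1.0) | 95.5 (+0.8) |
| WS_DAN | 89.4 | 92.2 | 93.0 |
| WS_DAN + SAC | 91.1 (+1.7) | 93.1 (+0.9) | 93.9 (+0.9) |
| ViT | 91.0 | 93.2 | 92.1 |
| ViT + SAC | 91.8 (+0.8) | 94.5 (+1.3) | 93.1 (+1.0) |
| Avg. Improvement | +1.3 | +1.2 | +1.2 |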

Table 2: Contribution (% Acc) of our Self Assessment Classifier (SAC) on fine-grained classification results.
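The average improvements in the last row can be checked directly from the per-method gains reported in parentheses. The short sketch below recomputes them; the gain values are copied from the table, while the variable names are only illustrative.

```python
# Per-method gains from adding SAC, in the order: Inception-V3, ResNet-50,
# WS, DT, MMAL, WS_DAN, ViT (values copied from the table).
gains = {
    "CUB-200-2011": [1.6, 1.9, 1.1, 0.9, 1.2, 1.7, 0.8],
    "Stanford Dogs": [1.7, 1.3, 1.1, 0.8, 1.0, 0.9, 1.3],
    "FGVC Aircraft": [1.8, 1.8, 0.9, 1.2, 0.8, 0.9, 1.0],
}

# Mean gain per dataset, rounded to one decimal as in the table.
average_gain = {
    dataset: round(sum(values) / len(values), 1)
    for dataset, values in gains.items()
}
# average_gain == {"CUB-200-2011": 1.3, "Stanford Dogs": 1.2, "FGVC Aircraft": 1.2}
```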

3. Qualitative Results

Attention Maps. Figure 1 illustrates the visualization of attention maps between image feature maps and each ambiguity class. The visualization indicates that by employing our Self Assessment Classifier, each fine-grained class focuses on different informative regions.

Figure 1. The visualization of the attention map between image feature maps and different ambiguity classes from our method. The red-colored class label denotes that the prediction is matched with the ground-truth.

Prediction Results. Figure 2 illustrates the classification results and corresponding localization areas of different methods. In all samples, our SAC focuses on different areas depending on the hard-to-distinguish classes. Thus, the method can focus on more meaningful areas and ignore unnecessary ones. Hence, SAC achieves good predictions even in challenging cases.

Figure 2. Qualitative comparison of different classification methods. (a) Input image and its corresponding ground-truth label, (b) ResNet-50, (c) WS_DAN, (d) MMAL, and (e) Our SAC. Boxes are localization areas. Red color indicates wrong classification result. Blue color indicates correct predicted label.

Conclusion

We introduce a Self Assessment Classifier (SAC) which effectively learns the discriminative features in the image and resolves the ambiguity among the top-k prediction classes. Our method generates an attention map and uses it to dynamically erase unnecessary regions during training. Extensive experiments on the CUB-200-2011, Stanford Dogs, and FGVC Aircraft datasets show that our proposed method can be easily integrated into different fine-grained classifiers and clearly improves their accuracy.

Fine-Grained Visual Classification using Self-Assessment Classifier (Part 1)

Extracting discriminative features plays a crucial role in the fine-grained visual classification task. Most existing methods focus on developing attention or augmentation mechanisms to achieve this goal. However, the ambiguity in the top-k prediction classes has not been fully investigated. In this paper, we introduce a Self Assessment Classifier, which simultaneously leverages the representation of the image and the top-k prediction classes to reassess the classification results. Our method is inspired by self-supervised learning with coarse-grained and fine-grained classifiers to increase the discrimination of features in the backbone and produce attention maps of informative areas on the image. In practice, our method works as an auxiliary branch and can be easily integrated into different architectures. We show that by effectively addressing the ambiguity in the top-k prediction classes, our method achieves new state-of-the-art results on the CUB-200-2011, Stanford Dogs, and FGVC Aircraft datasets. Furthermore, our method also consistently improves the accuracy of different existing fine-grained classifiers with a unified setup.

1. Introduction

The task of fine-grained visual classification involves categorizing images into subordinate categories of the same general class (e.g., various species of birds, types of aircraft, or different varieties of flowers). Compared to standard image classification tasks, fine-grained classification poses greater challenges due to three primary factors: (i) significant intra-class variation, where objects within the same category exhibit diverse poses and viewpoints; (ii) subtle inter-class distinctions, where objects from different categories may appear very similar except for minor differences, such as the color patterns of a bird's head often determining its fine-grained classification; and (iii) constraints on training data availability, as annotating fine-grained categories typically demands specialized expertise and considerable annotation effort. Consequently, achieving accurate classification results solely with state-of-the-art CNN models like VGG is nontrivial.

Recent research demonstrates that a crucial strategy for fine-grained classification involves identifying informative regions across various parts of objects and extracting distinguishing features. A common approach to achieving this is by learning the object's parts through human annotations. However, annotating fine-grained regions is labor-intensive, rendering this method impractical. Some advancements have explored unsupervised or weakly-supervised learning techniques to identify informative object parts or region of interest bounding boxes. While these methods offer potential solutions to circumvent manual labeling of fine-grained regions, they come with limitations such as reduced accuracy, high computational costs during training or inference, and challenges in accurately detecting distinct bounding boxes.

In this paper, we introduce the Self Assessment Classifier (SAC) method to tackle the inherent ambiguity present in fine-grained classification tasks. Essentially, our approach is devised to reevaluate the top-k prediction outcomes and filter out uninformative regions within the input image. This serves to mitigate inter-class ambiguity and enables the backbone network to learn more discerning features. Throughout training, our method generates attention maps that highlight informative regions within the input image. By integrating this method into a backbone network, we aim to reduce misclassifications among top-k ambiguous classes. It's important to note that "ambiguity classes" refer to instances where uncertainty in prediction can lead to incorrect classifications. Our contributions can be succinctly outlined as follows:

  • We propose a novel self-class assessment method that simultaneously learns discriminative features and addresses ambiguity issues in fine-grained visual classification tasks.
  • We demonstrate the versatility of our method by showcasing its seamless integration into various fine-grained classifiers, resulting in improved state-of-the-art performance.

Figure 1. Comparison between generic classification and fine-grained classification.

2. Method Overview

We propose two main steps in our method: Top-k Coarse-grained Class Search (TCCS) and Self Assessment Classifier (SAC). TCCS works as a coarse-grained classifier to extract visual features from the backbone. The Self Assessment Classifier works as a fine-grained classifier to reassess the ambiguity classes and eliminate the non-informative regions. Our SAC has four modules: the Top-k Class Embedding module aims to encode the information of the ambiguity class; the Joint Embedding module aims to jointly learn the coarse-grained features and top-k ambiguity classes; the Self Assessment module is designed to differentiate between ambiguity classes; and finally, the Dropping module is a data augmentation method, designed to erase unnecessary inter-class similar regions out of the input image. Figure.2 shows an overview of our approach.

Figure 2. Method Overview.

3. Top-k Coarse-grained Class Search

The TCCS takes an image as input. Each input image is passed through a Deep CNN to extract the feature map $\textbf{\textit{F}} \in \mathbb{R}^{d_f \times m \times n}$ and the visual feature $\textbf{\textit{V}} \in \mathbb{R}^{d_v}$. $m$, $n$, and $d_f$ represent the feature map height, width, and the number of channels, respectively; $d_v$ denotes the dimension of the visual feature $\textbf{\textit{V}}$. In practice, the visual feature $\textbf{\textit{V}}$ is usually obtained by applying some fully connected layers after the convolutional feature map $\textbf{\textit{F}}$.

The visual feature $\textbf{\textit{V}}$ is used by the $1^{st}$ classifier, i.e., the original classifier of the backbone, to obtain the top-k prediction results. Assume that the fine-grained dataset has $N$ classes. The top-k prediction results $C_k = \{C_1, ..., C_k\}$ form a subset of all prediction classes $C_N$, where $k$ is the number of candidates with the highest confidence scores.
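The top-k search itself reduces to ranking the classifier's confidence scores. The following is a minimal sketch in plain Python; the function names and the example logits are hypothetical and only illustrate the selection step, not the backbone.

```python
import math


def softmax(logits):
    """Convert raw classifier scores into confidence scores that sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_classes(logits, k):
    """Return (class index, confidence) for the k most confident classes, best first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]


# Hypothetical scores from the first classifier for N = 5 classes.
logits = [0.2, 2.5, 1.9, -0.3, 2.4]
candidates = top_k_classes(logits, k=3)
print([i for i, _ in candidates])  # [1, 4, 2]
```

Classes 1 and 4 have nearly identical scores here, which is exactly the kind of ambiguity the second classifier is meant to resolve.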

4. Self Assessment Classifier

Our Self Assessment Classifier takes the image feature $\textbf{\textit{F}}$ and the top-k prediction $C_k$ from TCCS as the input to reassess the fine-grained classification results.

Top-k Class Embedding

The output of the TCCS module $C_k$ is passed through the top-k class embedding module to output the label embedding set $\textbf{E}_k = \{E_1, ..., E_i, ..., E_k\}$, $i \in \{1, 2, ..., k\}$, $E_i \in \mathbb{R}^{d_e}$. This module contains a word embedding layer (Pennington et al., 2014) for encoding each word in class labels and a GRU layer (Cho et al., 2014) for learning the temporal information in class label names. $d_e$ represents the dimension of each class label embedding. It is worth noting that the embedding module is trained end-to-end with the whole model. Hence, the class label representations are learned from scratch without any pre-extracted or pre-trained embeddings or transfer learning.

Given an input class label, we trim the input to a maximum of $4$ words. A class label shorter than $4$ words is zero-padded. Each word is then represented by a $300$-D word embedding. This step results in a sequence of word embeddings of size $4 \times 300$, denoted $\hat{E}_i$ for the $i$-th class label in the class label set $C_k$. To capture the dependency within the class label name, $\hat{E}_i$ is passed through a Gated Recurrent Unit (GRU), which produces a $1024$-D vector representation $E_i$ for each input class. Note that although we use the language modality (i.e., the class label name), it is not extra information, as the class label name and the class label identity (used for calculating the loss) represent the same object category.
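The trimming and zero-padding step can be sketched as follows. This is an illustration only: the whitespace tokenization, the `<pad>` token, the random embedding table, and the example label are assumptions made for the sketch, and the GRU that maps the $4 \times 300$ sequence to a 1024-D vector is left out.

```python
import random

MAX_WORDS = 4      # maximum label length, as stated in the text
EMBED_DIM = 300    # word embedding size, as stated in the text
PAD = "<pad>"      # hypothetical padding token, mapped to an all-zero vector


def prepare_label(label, max_words=MAX_WORDS):
    """Split a class label into words, trim to max_words, and pad the rest."""
    words = label.split()[:max_words]
    return words + [PAD] * (max_words - len(words))


def embed(words, table, rng):
    """Look up (or lazily create) a 300-D vector per word; padding stays zero."""
    rows = []
    for w in words:
        if w == PAD:
            rows.append([0.0] * EMBED_DIM)
        else:
            if w not in table:
                # Stand-in for a learnable embedding initialised at random.
                table[w] = [rng.uniform(-0.1, 0.1) for _ in range(EMBED_DIM)]
            rows.append(table[w])
    return rows


rng = random.Random(0)
table = {}
words = prepare_label("Black footed Albatross")
sequence = embed(words, table, rng)
print(words, len(sequence), len(sequence[0]))
```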

Joint Embedding

This module takes the feature map $\textbf{\textit{F}}$ and the top-k class embedding $\textbf{E}_k$ as the input to produce the joint representation $\textbf{\textit{J}} \in \mathbb{R}^{d_j}$ and the attention map. We first flatten $\textbf{\textit{F}}$ into $(d_f \times f)$ and $\textbf{E}_k$ into $(d_e \times k)$. The joint representation $\textbf{\textit{J}}$ is calculated using the two modalities $\textbf{\textit{F}}$ and $\textbf{E}_k$ as follows:

$$\textbf{\textit{J}}^T = \left(\mathcal{T} \times_1 \text{vec}(\textbf{\textit{F}}) \right) \times_2 \text{vec}(\textbf{E}_k)$$

where $\mathcal{T} \in \mathbb{R}^{d_{\textbf{\textit{F}}} \times d_{\textbf{E}_k} \times d_j}$ is a learnable tensor; $d_{\textbf{\textit{F}}} = (d_f \times f)$; $d_{\textbf{E}_k} = (d_e \times k)$; $\text{vec}()$ is a vectorization operator; and $\times_i$ denotes the $i$-mode tensor product.

In practice, the preceding $\mathcal{T}$ is too large and infeasible to learn. Thus, we apply decomposition solutions that reduce the size of $\mathcal{T}$ but still retain the learning effectiveness. We rely on the idea of the unitary attention mechanism. Specifically, let $\textbf{\textit{J}}_p \in \mathbb{R}^{d_j}$ be the joint representation of the $p^{th}$ couple of channels, where each channel in the couple is from a different input. The joint representation $\textbf{\textit{J}}$ is approximated by using the joint representations of all couples instead of using the fully parameterized interaction in the preceding equation. Hence, we compute $\textbf{\textit{J}}$ as:

$$\textbf{\textit{J}} = \sum_p \mathcal{M}_p \textbf{\textit{J}}_p$$

Note that in the equation above, we compute a weighted sum over all possible couples. The $p^{th}$ couple is associated with a scalar weight $\mathcal{M}_p$. The set of $\mathcal{M}_p$ is called the attention map $\mathcal{M}$, where $\mathcal{M} \in \mathbb{R}^{f \times k}$.

There are $f \times k$ possible couples over the two modalities. The representations of the two channels in a couple are $\textbf{\textit{F}}_{i}$ and $\left(\textbf{E}_k\right)_{j}$, where $i \in [1,f]$ and $j \in [1,k]$, respectively. The joint representation $\textbf{\textit{J}}_p$ is then computed as follows:

$$\textbf{\textit{J}}_p^T = \left(\mathcal{T}_{u} \times_1 \textbf{\textit{F}}_{i} \right) \times_2 \left(\textbf{E}_k\right)_{j}$$

where $\mathcal{T}_{u} \in \mathbb{R}^{d_f \times d_e \times d_j}$ is the learnable tensor between the channels in the couple.

From the equation above, we can compute the attention map $\mathcal{M}$ using the reduced parameterized bilinear interaction over the inputs $\textbf{\textit{F}}$ and $\textbf{E}_k$. The attention map is computed as

$$\mathcal{M} = \text{softmax}\left(\left(\mathcal{T}_\mathcal{M} \times_1 \textbf{\textit{F}} \right) \times_2 \textbf{E}_k \right)$$

where $\mathcal{T}_\mathcal{M} \in \mathbb{R}^{d_f \times d_e}$ is the learnable tensor.

The joint representation $\textbf{\textit{J}}$ can be rewritten as

$$\textbf{\textit{J}}^T = \sum_{i=1}^{f}\sum_{j=1}^{k} \mathcal{M}_{ij} \left( \left( \mathcal{T}_{u} \times_1 \textbf{\textit{F}}_{i}\right) \times_2 \left(\textbf{E}_k\right)_{j} \right)$$

It is also worth noting from the equation above that to compute $\textbf{\textit{J}}$, instead of learning the large tensor $\mathcal{T} \in \mathbb{R}^{d_{\textbf{\textit{F}}} \times d_{\textbf{E}_k} \times d_j}$, we now only need to learn two smaller tensors: $\mathcal{T}_{u} \in \mathbb{R}^{d_{f} \times d_{e} \times d_j}$, used to compute each $\textbf{\textit{J}}_p$, and $\mathcal{T}_\mathcal{M} \in \mathbb{R}^{d_f \times d_e}$.
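To make the decomposition concrete, here is a small pure-Python sketch of the attention map and the attention-weighted joint representation, followed by a parameter count. Several things are assumptions of the sketch rather than details from the text: the toy dimensions, the random tensors, treating each $\textbf{\textit{F}}_i$ as one spatial location of the flattened feature map, and applying the softmax over all $f \times k$ couples.

```python
import math
import random


def attention_map(F, E, T_M):
    """M[i][j] from the bilinear score F_i^T T_M E_j, softmax over all couples (assumed)."""
    d_f, d_e = len(T_M), len(T_M[0])
    scores = [[sum(F_i[a] * T_M[a][b] * E_j[b]
                   for a in range(d_f) for b in range(d_e))
               for E_j in E] for F_i in F]
    m = max(max(row) for row in scores)
    exps = [[math.exp(s - m) for s in row] for row in scores]
    total = sum(sum(row) for row in exps)
    return [[e / total for e in row] for row in exps]


def joint_representation(F, E, T_u, M):
    """J = sum_ij M_ij * ((T_u x_1 F_i) x_2 E_j), a vector of length d_j."""
    d_f, d_e, d_j = len(T_u), len(T_u[0]), len(T_u[0][0])
    J = [0.0] * d_j
    for i, F_i in enumerate(F):
        for j, E_j in enumerate(E):
            for c in range(d_j):
                J_p_c = sum(T_u[a][b][c] * F_i[a] * E_j[b]
                            for a in range(d_f) for b in range(d_e))
                J[c] += M[i][j] * J_p_c
    return J


# Toy sizes chosen only to keep the example fast.
rng = random.Random(0)
d_f, f, d_e, k, d_j = 3, 4, 2, 2, 5
F = [[rng.uniform(-1, 1) for _ in range(d_f)] for _ in range(f)]
E = [[rng.uniform(-1, 1) for _ in range(d_e)] for _ in range(k)]
T_M = [[rng.uniform(-1, 1) for _ in range(d_e)] for _ in range(d_f)]
T_u = [[[rng.uniform(-1, 1) for _ in range(d_j)]
        for _ in range(d_e)] for _ in range(d_f)]

M = attention_map(F, E, T_M)
J = joint_representation(F, E, T_u, M)

# Parameters of the full tensor T versus the two reduced tensors.
full_params = (d_f * f) * (d_e * k) * d_j
reduced_params = d_f * d_e * d_j + d_f * d_e
print(len(J), full_params, reduced_params)  # 5 240 36
```

Even at these toy sizes the reduced form needs far fewer parameters, and the gap grows with $f$ and $k$ because the reduced tensors do not depend on them.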

Self Assessment

The joint representation $\textbf{\textit{J}}$ from the Joint Embedding module is used as the input in the Self Assessment step to obtain the $2^{nd}$ top-k predictions $\textbf{C}'_k$. Note that $\textbf{C}'_k = \{C'_1, ..., C'_k\}$. Intuitively, $\textbf{C}'_k$ is the top-k classification result after self-assessment. This module is a fine-grained classifier that produces the $2^{nd}$ predictions to reassess the ambiguity classification results.

The contributions of the coarse-grained and fine-grained classifiers are combined as

$$\text{Pr}(\rho = \rho_i) = \alpha \text{Pr}_1(\rho = \rho_i) + (1 - \alpha) \text{Pr}_2(\rho = \rho_i)$$

where $\alpha$ is the trade-off hyper-parameter $\left(0 \leq \alpha \leq 1\right)$. $\text{Pr}_1(\rho = \rho_i)$ and $\text{Pr}_2(\rho = \rho_i)$ denote the prediction probabilities for class $\rho_i$ from the coarse-grained and fine-grained classifiers, respectively.
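The weighted combination is easy to check on a small example. The class names, probabilities, and the value of $\alpha$ below are made up for illustration, and `fuse_predictions` is a hypothetical helper name.

```python
def fuse_predictions(pr1, pr2, alpha):
    """Combine coarse-grained (pr1) and fine-grained (pr2) class probabilities."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return {c: alpha * pr1[c] + (1.0 - alpha) * pr2[c] for c in pr1}


# Hypothetical probabilities over the same three top-k candidate classes.
pr1 = {"A": 0.5, "B": 0.4, "C": 0.1}  # first (coarse-grained) classifier
pr2 = {"A": 0.2, "B": 0.7, "C": 0.1}  # second (fine-grained) classifier

fused = fuse_predictions(pr1, pr2, alpha=0.5)
best = max(fused, key=fused.get)
print(best)  # B
```

In this example the first classifier alone would pick class A, while the fused score favours class B.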

Next

In the next post, we will verify the effectiveness and efficiency of the method.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 5)

In the previous post, we discussed the offline phase in detail. In this post, we explore the online phase in detail.

1. States in Online Phase

The online phase of the CFOR system is illustrated in Figure 1.

Figure 1. Online stage of the CFOR system.

2. Technical Details

The used functions will be described as follows:

  • (i) detector(imgQuery): an object in an image is automatically detected by a trained detector. In this function, we reuse YOLO (version 3.0) to identify fashion items. The identified items are also refined by the region identification model, which is trained by the “classifyModel” function.
  • (ii) infor_extract(states, obj, onto, classifyModels, multitaskModel): for each query object, all attribute learning models trained in function "multitaskModel" and coarse classification models in function "classifyModel" are run. We extract the region ⟶ category ⟶ attributes and the necessary features for each stage of the ontology.
  • (iii) query_expansion(infor, feat): query expansion based on the mean vector is used for reranking retrieval results.
  • (iv) compute_sim_score(database, infor, feat): for each pair of features, the asymmetric distance is used to measure the dissimilarity between the query and each sample in the database.
  • (v) ranking(score_list, database, top_k, GPU_search = True): based on the scores between the query and all samples in the database obtained from function "compute_sim_score," ranking is applied; smaller is better.
  • (vi) retrieval(indexes, score_list, database, GPU_search = True): the retrieval process contains three steps: global retrieval, fine-grained retrieval, and query expansion. For global retrieval, global features of the query object obtained from function “infor_extract” and the features of samples in the database are passed to function "ranking" to get the 1st top-m retrieval results. For fine-grained retrieval, attribute features of the query object obtained from function “infor_extract” and the features of samples in the 1st top-m retrieval results are passed to function "ranking" to get the 2nd top-k retrieval results. For query expansion, the mean vector is computed from the 2nd top-k retrieval results, and each feature of the 2nd top-k retrieval results is passed to function "ranking" to get the final top-k retrieval results, i.e., query expansion-based reranking.

As described in Figure 1, the online phase of the CFOR system contains three stages which will be put into use in real time. They are given as follows.

3. Prediction Stage

This stage takes advantage of the object ontology and the classification models obtained from the offline phase and makes predictions from coarse to fine for each query image:

Fine-grained information in terms of regions, categories, and attributes gives a customer more options for forming a full semantic query. The object is predicted from coarse to fine: the region, category, and attributes are predicted in turn based on the object ontology and a local MDNN. The object retrieval system then uses the extracted semantic information, such as the category and attributes, to search in detail.

4. Dissimilarity Measuring Stage

This stage will take advantage of the database as well as the indexing file from the offline phase and a dissimilarity measure to get scores and then rank, rerank, and release retrieval results for each query image. This stage is based on the dissimilarity measure between attribute vectors of query images and database images:

Based on a combination of K-nearest neighbour search in terms of L2 distance and asymmetric distance computation, we take advantage of parallel GPU processing through the Faiss method to compute the distance from the query image to each relevant image in the database. The distance, which is also called the score of each image in the database, is then sorted to rank the dissimilarity. The smaller the score of an image, the more similar it is to the query. Based on the number of retrieval images required or on thresholds, we apply an appropriate cutoff on the score or on the number of retrieved images. This kind of measurement is used to compute distances for both deep feature vectors and attribute vectors.
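The ranking rule (smaller score means more similar) can be illustrated with a brute-force L2 search in plain Python. This is a stand-in, not the system's implementation: it replaces the GPU-based Faiss search and the asymmetric distance computation with an exhaustive loop, and the image names and feature vectors are invented.

```python
import math


def l2_distance(a, b):
    """Euclidean distance between two feature vectors of equal length."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def rank_by_distance(query, database, top_k):
    """Return the top_k (name, score) pairs, smallest distance first."""
    scored = [(l2_distance(query, feat), name) for name, feat in database.items()]
    scored.sort()
    return [(name, score) for score, name in scored[:top_k]]


# Hypothetical 2-D features; real deep features have many more dimensions.
database = {
    "img_a": [0.0, 0.0],
    "img_b": [1.0, 0.0],
    "img_c": [3.0, 4.0],
}
query = [0.9, 0.0]

ranked = rank_by_distance(query, database, top_k=2)
print([name for name, _ in ranked])  # ['img_b', 'img_a']
```

The `top_k` argument plays the role of the cutoff on the number of retrieved images; a score threshold could be applied to the sorted list in the same way.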

5. Query Expansion Stage

Query expansion is a technique that can help gather additional relevant information from the input to increase retrieval performance. The information can be relevant images, additional features, description, etc. based on the query expansion algorithms and data. In this stage, we would like to take advantage of the previous retrieval results and then expand the query by using the mean vector to rerank and get reranked retrieval results to improve retrieval performance.

Among the many available methods, we choose query expansion based on the mean vector. The mean vector, computed from the features of the retrieval results and the features of the input, helps reduce the bias between the different features considered. Thus, the CFOR system can eliminate unrelated results, that is, retrieved features that lie far from the mean vector, which helps reduce outliers and raise the precision score.

Query expansion based on the mean vector is very fast to compute, and it can also take advantage of the Faiss similarity search method. Query expansion can remove outliers thanks to the statistical nature of the mean vector.
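A minimal sketch of mean-vector query expansion is shown below. It assumes that the mean is taken over the query feature together with the features of the current results, and it uses an exhaustive squared-L2 comparison instead of Faiss; the names and vectors are invented for the example.

```python
def mean_vector(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def sq_l2(a, b):
    """Squared Euclidean distance (enough for ranking)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def expand_and_rerank(query, results, database, top_k):
    """Rerank the current results by distance to the expanded (mean) query."""
    feats = [query] + [database[name] for name in results]
    expanded = mean_vector(feats)
    reranked = sorted(results, key=lambda name: sq_l2(expanded, database[name]))
    return reranked[:top_k]


# Hypothetical features: "d" is an outlier among the first-pass results.
database = {
    "a": [1.0, 0.0],
    "b": [1.2, 0.1],
    "c": [0.9, -0.1],
    "d": [5.0, 5.0],
}
query = [1.0, 0.0]
first_pass = ["a", "c", "b", "d"]

final = expand_and_rerank(query, first_pass, database, top_k=3)
print(final)  # ['b', 'a', 'c']
```

With `top_k=3` the outlier "d" is ranked last and dropped from the final list.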

Next

In the next post, we will present the fashion ontology and the testing of the CFOR system on fashion data.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 4)

In the previous post, we gave an overview of the offline phase and the online phase. In this post, we explore the offline phase in detail.

1. States in Offline Phase

This phase consists of three substages:

  • Object Ontology Establishment Stage. This stage defines fashion ontology to control the training flow as well as the online retrieval flow which serves as a bridge between high-level concepts (objects and categories), midlevel concepts (attributes), and raw data.
  • Learning Stage. This stage exploits deep networks with transfer learning in dealing with the specific tasks including object part learning, category learning, and attribute learning.
  • Storing and Indexing Stage. This stage defines a way of storing data as well as making the index list to reduce retrieval or searching time.

In this section, we describe the parts of the offline phase that are inherited from previous state-of-the-art methods: object part extraction, transfer learning and its role in the retrieval system, and data storing. These modules generalize readily to any object. Other issues, including the ontology, attribute learning, network architecture, and indexing strategy, will be detailed in the following sections.

2. Loss Function

The classification models build on the state-of-the-art ResNet, and the cross-entropy loss function is applied for multiclass classification in the category classification model and the attribute classification model.

For attribute multitask classification models, the loss function is described as follows:

3. Technical Details

The object ontology is designed manually based on professional experience and public datasets for the community. It is organized into a hierarchical semantic tree with three main levels: region level, category level, and attribute level. Regions, categories, and attributes are learned automatically based on the local MDNN. For the DeepFashion dataset, the functions used are described as follows:

  • (i) extract_predicates(dta): in a richly annotated dataset, e.g., DeepFashion, a sample image can be annotated with many labels at different fine-grained levels. For each fine-grained level, the function extracts the unique possible labels of the samples and stores them in a corresponding array. For example, in the DeepFashion dataset, Top, Bottom, and Body are unique labels belonging to one fine-grained level, and thus, they are stored in one array. Similarly, fabric, shape, part, style, and texture labels belong to one fine-grained level and are stored in one array.
  • (ii) build_ontology(predicates, prior): this matches the extracted level and its labels from each predicate array to the corresponding stage of the general ontology, i.e., prior. For example, Top, Bottom, and Body belong to one level, which is matched with the region stage of the ontology. After the matching is finished, all unused stages are eliminated from the general ontology to generate the adapted ontology, e.g., the fashion ontology.
  • (iii) extract_state(onto): from the built ontology, all stages and their labels are searched and stored in arrays which will be used to reconstruct the data. For example, the region stage array contains three classes, and the category stage array contains 50 classes.
  • (iv) extract_nes_dta(dta, state, onto): based on the stage and the classes extracted by the “extract_state” function, the whole DeepFashion dataset is split. Only samples having labels that belong to the stage are stored as the training set of that stage in the ontology. For example, for the region classification model, only samples labelled Top, Body, or Bottom are used for training.
  • (v) classifyModel(architecture, state_dta): in the DeepFashion dataset, based on the ontology, there are four classification models: one region classification model and one category classification model each for the Top region, Body region, and Bottom region. These models are retrained from the ImageNet dataset using ResNet-10.
  • (vi) multitaskModel(group_state_dta, architecture, Matthrew_coef = True): for each group state at the fine-grained attribute level, a multitask classification model is built, e.g., a fabric attribute group classification model and a style attribute group classification model. These models are retrained from the ImageNet dataset using NASNet v3. In addition, MCC is used in attribute learning as an imbalanced data solver.
  • (vii) indexing(state_sta): indexing files are created that will be used for run-time retrieval. The method is based on nonexhaustive compressed-domain search with GPU.
  • (viii) build_storage(onto, states): storage structures are automatically created based on the built object ontology and the extracted states.
  • (ix) infor_extract(states, dta, onto, classifyModels, multitaskModel): for each sample in the database, all attribute learning models trained in the “multitaskModel” function are run, and then all attributes whose scores are higher than the thresholds are extracted.
  • (x) feat_extract(dta, onto, classificationModels, multitaskModel): for each sample in the database, the features of the pre-softmax layer in the four models trained in the “classifyModel” function are obtained.
  • (xi) structure(storage, feat_dta, info_dta, indexFiles): the database is automatically built based on the extracted features, extracted information, index, and storage structure.

3.1. Object Part Extraction

For the aforementioned reasons, foreground objects should be extracted from background regions efficiently and accurately before entering the retrieval step. The target of object extraction is to filter the necessary specific subjects. This also improves the efficiency of the following modules as well as increases the overall system performance. There are many successful object detection methods. Among them, YOLO shows the state-of-the-art results. In our system, we inherited the successful software YOLO (version 3.0) to identify fashion items.

3.2. Transfer Learning

Transfer learning is one of the best methods to reduce training time, especially with complicated architectures such as ResNet or NASNet. The key issue is the choice of initial parameters. Without transfer learning, these parameters must be generated in the first step of the training process with some unsupervised learning method, and they will be far from the optimal ones. With transfer learning, we reuse parameters trained on a large and diverse dataset (such as ImageNet). In this way, training converges more easily, which reduces the training time.

Transfer learning can be applied in different ways based on the size of the dataset and data similarity. There are four scenarios in total. First, if the data size is small while data similarity is high, we use the pretrained model as a feature extractor. Second, if the data size is small and data similarity is low, we freeze the top layers and train the remaining layers of the pretrained model. Third (ideal situation), if the data size is large and data similarity is high, we can retrain the model starting from the weights of the pretrained one. Fourth (worst situation), if the data size is large and data similarity is low, transfer learning cannot be applied, and we have to train our model from scratch. In our fashion experiments, since DeepFashion is a large dataset and ImageNet (the dataset used for transfer learning) is highly diverse, we can use all of the initialized weights from the pretrained model.
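The four scenarios amount to a small decision table. The function below simply restates them in code; the function name and the boolean inputs are our own simplification, since in practice "large" and "similar" are judgment calls rather than flags.

```python
def transfer_strategy(data_is_large, similarity_is_high):
    """Map dataset size and similarity to one of the four scenarios in the text."""
    if not data_is_large and similarity_is_high:
        return "use the pretrained model as a feature extractor"
    if not data_is_large and not similarity_is_high:
        return "freeze the top layers and train the remaining layers"
    if data_is_large and similarity_is_high:
        return "retrain the model starting from the pretrained weights"
    return "train from scratch"


# The fashion setting described above: large dataset, diverse pretraining data.
print(transfer_strategy(data_is_large=True, similarity_is_high=True))
```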

According to our approach, transfer learning will be applied in region, category classification as well as attribute learning along with ResNet and NASNet architectures, respectively. It can also be used in global deep feature extraction to improve the overall retrieval performance.

3.3. Data Storing

Features extracted from the category classification task and attribute learning will be stored in a hierarchical semantic tree based on object ontology. All features belong to a leaf of object ontology and will be stored in one file. In case of the expansion of large-scale data, the mentioned files can be indexed and split with a corresponding mapping key for each image. The folders will be organized based on object ontology in which each name corresponds to each concept. To clarify, data storing for the proposed ontology is defined as follows (see Figure below for an example of data storing):

  • (i) All files are stored in a folder named “database,” which is denoted as the “Object” node.
  • (ii) Based on ontology, “Object” node contains 3 nodes at the “Region” semantic level. Thus, we have 3 smaller folders: “Top,” “Body,” and “Bottom.”
  • (iii) At the next stage of ontology, we have the “Category” semantic level. Thus, we have 50 folders representing all nodes of “Category.”
  • (iv) Finally, we have the “Attribute” semantic level standing for the leaf node state in ontology. At this state, all features belong to the same “Region” and “Category” and are stored in one file.

Figure 1. An example of the storing structure.
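The folder layout described above can be expressed as a simple path rule. This is a sketch under assumptions: the file name `features.bin` and the category name `Blouse` are placeholders chosen for the example, and only the region names and the `database` root come from the text.

```python
from pathlib import PurePosixPath

REGIONS = ("Top", "Body", "Bottom")  # the three "Region" nodes of the ontology


def feature_file(region, category, root="database"):
    """Path of the single feature file shared by one (region, category) leaf."""
    if region not in REGIONS:
        raise ValueError(f"unknown region: {region}")
    # "features.bin" is a placeholder; the source does not name the file.
    return PurePosixPath(root) / region / category / "features.bin"


print(feature_file("Top", "Blouse"))  # database/Top/Blouse/features.bin
```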

Next

In the next post, we will describe the online phase in detail.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 3)

In the previous post, we discussed the imbalanced data problem and gave an overview of the object retrieval system. The CFOR system is organized into two main phases: the offline phase and the online phase. In this post, we give an overview of both phases of the proposal.

1. Overview of Offline Phase

This phase is designed to generate object ontology, database, indexing file and region detection model, category classification model, and attribute classification model.

  • Object ontology is designed manually based on professional experience and public dataset for the community. It is organized into a hierarchical semantic tree with three main levels: region level, category level, and attribute level.
  • The database is generated to store the preextracted features, regions, categories, and attributes of all images in the dataset. It helps reduce the online retrieval time and provides the necessary semantic information for each retrieved image.
  • The indexing file which is created to support fast mapping in the online phase of the CFOR retrieval system is the key to perform the retrieval task at runtime.
  • Regions, categories, and attributes are learned automatically based on the local MDNN. Detection models and classification models are created to extract or predict semantic information of the query image and dataset such as regions, categories, and attributes.

2. Overview of Online Phase

This phase of the CFOR system is designed to run the retrieval process including object detection, semantic information extraction, and query expansion and retrieval.

In the object detection stage, we use the trained object detector to detect objects in the query image. In the semantic information extraction stage, the built-in object ontology and classification models are used for extracting the necessary semantic information of each identified object. The extracted semantic information and deep global features of each detected object are passed through the searching system along with the indexing file to quickly compute the score between the query object and each sample in the database. Retrieval is applied to rank and export the images most similar to the query object along with their relevant information. Query expansion is optional and is used to increase the retrieval performance at the cost of retrieval time.

The power of mutually supporting object ontology, local MDNN, and imbalanced data solver in the CFOR system: Figure 1 shows the operation of the CFOR system with the interaction of its three main modules (object ontology, a local MDNN, and an imbalanced data solver) to optimize the learning strategy and improve the overall retrieval performance on large-scale datasets.

Figure 1. Synthesis of object ontology, deep learning, and imbalanced data problem solver in the CFOR system.

Object ontology supports conducting the training flow (with a local MDNN) and retrieval flow (from the coarse-grained level to the fine-grained level) to save computational costs in the training stage and retrieval stage on large-scale datasets. Training flow also paves a way for applying transfer learning which may improve the convergence rate of deep networks. Object ontology which could transform the global imbalance of data into local imbalance of data based on fine-grained groups makes the imbalanced data problem easier to deal with.

The deep multitask NN links the object ontology to the raw data effectively at the category level and attribute level by exploiting inner-group correlations and intergroup correlations. The object ontology in turn allows the system to be updated at the local level, with parallel processing based on the local MDNN. Therefore, CFOR can be updated flexibly on large-scale datasets with many variations. Finally, the proposed imbalanced data solver based on MCC addresses data imbalance and thereby improves the quality of the object ontology implementation without adjusting the network architecture or augmenting the data.

Algorithm and demonstration of the CFOR system: an online phase and offline phase (Figures 2 and 3) are used to analyze tasks in the CFOR system. These phases will be demonstrated in detail in this section.

Figure 2. Offline stage of the CFOR system.

Figure 3. Online stage of the CFOR system.

Besides, the CFOR system can be put into use as a general solution for retrieval. To evaluate the performance of the proposed system, fashion objects with attributes are selected in experiments.

Next

In the next post, we will cover the offline phase in detail.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 2)

In the previous post, we discussed multitask learning and the object retrieval system. In this post, we look at the imbalanced data problem and the main proposal.

1. Imbalanced Data Problem

Imbalanced data is a problem in machine learning in which the class distribution is not uniform across classes. Usually, the data are composed of two types of classes: the majority classes (positive) and the minority classes (negative). Recent research in machine learning shows that learning from an uneven distribution of class examples can give learning algorithms misleading (biased) performance: a classifier may reach high accuracy on the majority class but poor accuracy on the minority class. In the case of attribute learning, an imbalance occurs if the number of instances of some attributes differs significantly from that of other attributes. To deal with this situation, many popular methods for handling imbalanced data essentially adjust the distribution of classes.

  • Data sampling: sampling-based methods such as upsampling, downsampling, or data augmentation are considered to be a solution for imbalanced data problems. In addition to making data more balanced, they can help reduce training time (downsampling) or make the learning process more efficient (upsampling). The best approach we know is SMOTE, which automatically generates additional data (upsampling) based on the original dataset. However, these methods either increase overfitting during training (upsampling) or lose data (downsampling). Data augmentation has proved to be robust in dealing with imbalanced training data. However, it takes up a lot of training resources, and it is difficult to find a proper augmented dataset that is large enough to train on. Moreover, it is very difficult (or impossible) to augment data to balance the attributes in a dataset because each object usually has many attributes.

    Figure 1. A simple data sampling

  • Architecture, loss function, and metric configuration: other methods exploit network architectures, loss functions, or metrics to address the imbalanced data problem during training. These algorithm-level methods enhance the existing classifier by adjusting the algorithm to recognize the smaller classes. Such internal techniques provide general solutions for the imbalanced data problem because they are not specific to particular problems. These approaches show better performance than data sampling; however, they are often difficult to implement and to reconfigure later. Therefore, they are not always the best choice in dynamic retrieval systems in which the attributes vary widely.

  • Threshold and output-based configuration: instead of generating more data or changing the model, these methods find the best thresholds based on the output. In essence, they take the scores that express the probability of a test sample belonging to a class and produce several learners by varying the threshold for class membership. These methods are particularly effective in resolving imbalanced data problems without changing the model configuration. Moreover, they neither reduce data nor increase overfitting. SVM has been proposed to find these thresholds, while Boughorbel et al. proposed Matthews’ correlation coefficient (MCC) to deal with imbalanced data in classification. Although SVM shows better performance, MCC consumes fewer resources and less processing time. Building on the methods of many other researchers, we found a solution for multitask learning that suits retrieval systems: an end-to-end DCNN for training and MCC for estimating the thresholds that produce the final outputs.

Figure 2. Metric losses
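To make the threshold-based idea concrete, here is a minimal sketch that picks a decision threshold for one attribute by maximizing MCC. It assumes a simple grid search over candidate thresholds; the function names, the candidate grid, and the scores are hypothetical and are not taken from the CFOR code.

```python
import math


def mcc(y_true, y_pred):
    """Matthews' correlation coefficient for binary labels (0 or 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention, MCC is 0 when any marginal count is zero.
    return (tp * tn - fp * fn) / denom if denom else 0.0


def best_threshold(y_true, scores, candidates):
    """Return the (threshold, mcc) pair with the highest MCC among candidates."""
    best_t, best_m = None, -2.0
    for t in candidates:
        y_pred = [1 if s >= t else 0 for s in scores]
        m = mcc(y_true, y_pred)
        if m > best_m:
            best_t, best_m = t, m
    return best_t, best_m


# Hypothetical network scores for one rare attribute (2 of 10 samples carry it).
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
scores = [0.05, 0.10, 0.10, 0.15, 0.20, 0.20, 0.25, 0.30, 0.35, 0.45]

default_mcc = mcc(y_true, [1 if s >= 0.5 else 0 for s in scores])
threshold, tuned_mcc = best_threshold(y_true, scores, [0.1, 0.2, 0.3, 0.4, 0.5])
print(default_mcc, threshold, round(tuned_mcc, 3))
```

With these toy scores, the usual cut-off of 0.5 labels no sample as carrying the attribute, so its MCC is 0, while the search settles on a lower threshold. The trained model itself is left untouched, which is the appeal of output-based methods.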

2. Materials and Methods: CFOR System

The CFOR system has many components, but its core ideas are easy to understand. We focus on its main points.

CFOR is an object retrieval system integrated by object ontology, a local MDNN (NASNet and ResNet), and an imbalanced data solver (MCC) to improve the performance of the large-scale object retrieval system from the coarse-grained level (categories) to the fine-grained level (attributes) (see Figure below).

Figure 2. Synthesis of object ontology, deep learning, and imbalanced data problem solver in the CFOR system.

Query Image. In traditional content-based image retrieval systems, a query image only lets one retrieve images ranked by visual similarity to it. It is very difficult (or impossible) for users to provide semantic information to the system through query images. In our CFOR system, this challenge has been solved. The semantic information of the query image is extracted automatically by the category and attribute classification system, and users can use the extracted semantic information during the retrieval process. For example, users can query “Asian face” with only a query image; here, “Asian race” is the semantic information. Traditional retrieval methods cannot meet this requirement because of the semantic gap, whereas the CFOR system can recognize “Asian race” and use it for retrieval. Another example, for a “Fashion” object in our CFOR system, is described in Figure 3.

Figure 3. Extracting regions, categories, and attributes from a query image with trained models of the CFOR system. After that, users can use this semantic information to reduce the searching space.

From the query image, based on fashion ontology, the detector quickly identifies the region (Top and Bottom; see Figure 4).

Figure 4. Fashion ontology used to retrieve.

After that, the user selects the region (Top; see Figure 5); the CFOR system quickly identifies the category related to the Top region (category: Blazer). Later, specific concepts and visual concepts are extracted according to Blazer, and users can select some of them (or all of them) to retrieve. For user-friendly interaction, only extracted regions, categories, and attributes are shown. Other information such as global deep features, attribute vector, ontology, or group of attributes which are used as searching input of the system will not be displayed. In such a way, users can order the CFOR system at the semantic level, and they can achieve the results that match both the content and semantics of the query image.

Figure 5. Retrieval results.

Next

In the next post, we will cover the online phase and offline phase of the proposed retrieval system.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 1)

Object retrieval plays an increasingly important role in video surveillance, digital marketing, e-commerce, etc. It faces challenges such as large-scale datasets, imbalanced data, viewpoint, cluttered background, and fine-grained details (attributes). This paper proposes a model that integrates object ontology, a local multitask deep neural network (local MDNN), and an imbalanced data solver to exploit the advantages and overcome the shortcomings of deep learning network models, improving the performance of the large-scale object retrieval system from the coarse-grained level (categories) to the fine-grained level (attributes). Our proposed coarse-to-fine object retrieval (CFOR) system can be robust and resistant to the challenges listed above. To the best of our knowledge, the main novelty of our CFOR system is the mutual support of object ontology, a local MDNN, and an imbalanced data solver in a unified system. Object ontology supports the exploitation of the inner-group correlations to improve the system performance in category classification and attribute classification, and it guides the training flow and retrieval flow to save computational costs in the training stage and retrieval stage on large-scale datasets, respectively. A local MDNN supports linking object ontology to the raw data, and an imbalanced data solver based on Matthews’ correlation coefficient (MCC) addresses data imbalance, which contributes effectively to increasing the quality of object ontology realization without adjusting network architecture and data augmentation. In order to evaluate the performance of the CFOR system, we experimented on the DeepFashion dataset. This paper shows that our local MDNN framework based on the pretrained NASNet architecture achieves better performance (14.2% higher in recall rate) compared to single-task learning (STL) in the attribute learning task; it also shows that our model with an imbalanced data solver achieves better performance (5.14% higher in recall rate for fewer data attributes) compared to models that do not take this into account. Moreover, MAP@30 hovers around 0.815 in retrieval on an average of 35 imbalanced fashion attributes.

1. Introduction

Nowadays, object retrieval is facing some challenges and has some advantages.

Query format plays a very important role in large-scale object retrieval systems. Thus, the query format should be user-friendly and satisfy user requirements in practice.

Two query formats are popular these days: the image-based format and the text-based format. The text-based query format is widely used in many search systems. However, in many cases, it is very difficult to use query text to express the content that a person would like to retrieve because words are limited in expressing visual information. In contrast, a query image is worth more than a thousand words; it allows customers to search for objects without typing, and, most importantly, it can retrieve results based on content. Nevertheless, the limitations of the query image in expressing semantic information could decrease the overall retrieval performance. Thus, the query image and retrieved images with useful related information (regions, categories, fine-grained attributes, etc.) are the points that we focus on to improve the performance of the coarse-to-fine object retrieval system.

Object retrieval systems should meet the requirements of retrieving from large-scale datasets not only at the coarse level but also at the detailed level (or attribute level). For example, in face retrieval systems, facial attribute retrieval is often required. In fashion retrieval systems, fashion attribute retrieval is an indispensable requirement. In person reidentification systems, in the reidentification stage, besides using the global features of the whole human body, attribute vectors of the face and clothes are also being exploited effectively. In crowd attribute recognition systems, the useful attribute set consisted of location, participants, and activities.

Objects often have multiple attributes, and there are methods to retrieve objects at the attribute level from large-scale datasets without manual annotation. In attribute recognition, traditional methods often waste a lot of time selecting hand-crafted features for each attribute group through trial and error, yet they do not always achieve the desired results. In recent years, the deep convolutional neural network (DCNN) has demonstrated high performance in many computer vision tasks such as detection, classification, recognition, and retrieval. The DCNN is also used for attribute learning: with only one network architecture, a DCNN model can learn to recognize many attributes.

The DCNN-based attribute learning model will not perform well if all attributes play the same role at the output level of the network architecture and the data imbalance is left unresolved. To exploit the inner-group correlations in coarse-grained groups or fine-grained groups, the DCNN is often revised into a deep multitask NN. Classification performance improves if the elements of fine-grained category groups or fine-grained attribute groups can share similar learning features, so that the slope of their error surface becomes more uniform and the deep multitask learning algorithm can reach the global optimum more easily.

Object ontology plays an important role in category classification, attribute classification, and conducting training flow and retrieval flow to save computational costs in the training stage and retrieval stage on large-scale datasets, respectively. Thus, based on our experience in researching objects related to attributes such as face, cloth, person (reidentification), crowd (monitoring), and fast filters in large-scale object retrieval, we would like to introduce an object ontology as a hierarchical semantic tree with three levels: region, category, and attribute levels. The attribute level consisted of visual concepts and specific concepts. Visual concepts support linking common visual attributes to arbitrary objects.
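A three-level ontology of this kind can be pictured as a nested mapping from regions to categories to attribute groups. The sketch below is illustrative only: the region names and the Blazer category follow the fashion examples in this series, while the other categories, the assignment of entries to visual or specific concepts, and the helper functions are hypothetical.

```python
# Hypothetical three-level ontology: region -> category -> attribute groups.
# The concrete entries below are illustrative, not the published ontology.
ONTOLOGY = {
    "Top": {
        "Blazer": {
            "visual_concepts": ["color", "texture"],
            "specific_concepts": ["collar", "sleeve"],
        },
        "Tee": {
            "visual_concepts": ["color", "texture"],
            "specific_concepts": ["neckline"],
        },
    },
    "Bottom": {
        "Jeans": {
            "visual_concepts": ["color", "texture"],
            "specific_concepts": ["fit"],
        },
    },
}


def categories_of(region):
    """Coarse level: categories that belong to a detected region."""
    return sorted(ONTOLOGY[region])


def attributes_of(region, category):
    """Fine level: visual concepts first, then category-specific concepts."""
    node = ONTOLOGY[region][category]
    return node["visual_concepts"] + node["specific_concepts"]


print(categories_of("Top"))
print(attributes_of("Top", "Blazer"))
```

Note how the visual concepts repeat across categories while the specific concepts differ, which mirrors the idea that visual concepts link common attributes to arbitrary objects.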

We introduce an object ontology based on popular large-scale standard datasets in the scientific community, so we hope that our ontology can meet the criterion “widely recognized in community.” For the criterion “realization,” we have proposed the local MDNN to support linking object ontology to the raw data. However, if the object ontology cannot be linked with high quality, it cannot function effectively. We have therefore proposed the imbalanced data solver based on MCC, which addresses data imbalance and effectively increases the quality of linking object ontology to raw data without adjusting network architecture and data augmentation.

We review some typical works based on object ontology, deep multitask neural networks, and imbalanced data solvers to highlight our contributions.

Most of the works only present the set of attributes in the form of item lists or item groups. A few works use the term "ontology", but to the best of our knowledge, there are no works that present the object ontology in the full meaning of regions, categories, and attributes.

In [8], FashionNet handles challenges such as deformation and occlusions by explicitly predicting clothing landmarks and pooling features over the estimated landmarks, resulting in a more discriminative clothing representation. The authors do not use the term "ontology", but the DeepFashion dataset is organized as a hierarchical tree; it is only deployed for fashion, and it is a two-level tree: the first level consists of 50 categories and the second level consists of 5 attribute groups (texture, fabric, shape, part, and style), with no color attribute. The coarse-grained groups (at the category level) and fine-grained groups (at the attribute level) play the same role in the deep neural network, and no imbalanced data solver is considered.

Our idea is to improve the performance of deep neural networks based on object ontology and imbalanced data solvers, with inspiration from Gödel’s incompleteness theorems. These theorems show the limitations of any consistent formal system and, by analogy, the limitations of any specific method in solving problems. When tuning the deep network configuration no longer yields effects as large as it did in the early days, it is necessary to integrate object ontology and imbalanced data solvers into deep learning. Based on appropriate interventions in inputs and outputs, we introduce a new method that can help improve the performance of the object retrieval system.

The main contributions of this paper are as follows.

  • Our proposed unified model consists of object ontology, a local MDNN, and an imbalanced data solver to improve the performance of the large-scale object retrieval system from the coarse-grained level (categories) to the fine-grained level (attributes).

  • Our proposed object ontology is a hierarchical semantic tree consisting of three main levels: region, category, and attribute levels. It can support the optimal learning strategy and minimize the effect of semantic gap. It is useful to improve the performance of category classification, attribute classification, and conducting training flow and retrieval flow to save computational costs in the training stage and retrieval stage on large-scale datasets, respectively.

  • Our proposed local MDNN is inspired by multitask neural networks. It is based on NASNet and ResNet and exploits the local multitask neural network architecture to improve the performance of category classification and attribute classification and to allow flexible system updates. The local MDNN supports linking object ontology to raw data and takes advantage of inner-group correlations of categories and attributes. If the inner-group correlations (or intergroup correlations) are exploited, classification performance improves because the elements of a fine-grained category group or fine-grained attribute group share similar learning features, the slope of their error surface becomes more uniform, and our deep local multitask learning algorithm can reach the global optimum more easily.

  • Data imbalances often occur for large-scale datasets. Data augmentation is almost impossible because each object can have multiple attributes. The solution based on the loss functions, as in [6], may be possible, but it cannot exploit transfer learning. Our proposed imbalanced data solver is inherited from MCC without adjusting network architecture and data augmentation. It is integrated into the local MDNN to improve the performance of category classification and attribute classification, but it can still exploit transfer learning to reduce computational costs in the training stage on large-scale datasets.

  • Our proposed query format is based on object ontology with semantic information such as regions, categories, and attributes extracted automatically from the query image. Therefore, we can express semantic information from the image to the retrieval process that the traditional methods have not implemented yet.

Figure 1. A use case of an object retrieval system.

2. Object Retrieval System

Fine-grained object retrieval searches for similar images that include specific object attributes. It marks a transition from image retrieval to object attribute retrieval. Specifically, unlike traditional image retrieval systems where queries and results are often coarse (e.g., texts or images), fine-grained image retrieval aims to retrieve semantic information such as categories and attributes. In the fashion field, taking advantage of semantic information, an object retrieval method based on the combination of the global feature with fine-grained attribute information was introduced [8]. Inspired by previous works, we propose a coarse-to-fine object retrieval system that not only takes advantage of the combination of the global feature with fine-grained attribute information but also optimizes the learning strategy based on ontology and resolves the imbalanced data problem by intervening at the output.

In addition to meeting the semantic retrieval results, the object retrieval system must handle large-scale problems to run in real time. However, most solutions did not take advantage of the power of GPUs for parallel processing which can significantly reduce feature-matching time and retrieval time. To leverage the support of GPUs, we inherited the search algorithm introduced by Johnson et al. (billion-scale similarity search with GPUs) which is a nonexhaustive similarity search. The search method perfectly suited the proposed CFOR system which further decreased searching time by creating multi-index files based on built-in object ontology.
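The multi-index idea can be illustrated without any GPU library. The sketch below does not use the search method of Johnson et al.; it only shows, under the assumption that each indexed item is tagged with a region and a category, how one bucket per ontology node shrinks the search space. The function name and the records are hypothetical.

```python
from collections import defaultdict


def build_multi_index(items):
    """Group (item_id, region, category) records into one bucket per ontology node."""
    index = defaultdict(list)
    for item_id, region, category in items:
        index[(region, category)].append(item_id)
    return index


# Toy records, invented for illustration.
items = [
    ("img1", "Top", "Blazer"),
    ("img2", "Top", "Tee"),
    ("img3", "Bottom", "Jeans"),
    ("img4", "Top", "Blazer"),
]
index = build_multi_index(items)

# Only the bucket matching the query's region and category is searched.
search_space = index[("Top", "Blazer")]
print(search_space, f"{len(search_space)}/{len(items)} items scanned")
```

In a real system each bucket would hold its own similarity index rather than a plain list, but the saving comes from the same place: most of the database is never touched.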

Figure 2. Original object retrieval system.

3. Attribute Learning

Attribute learning is a backbone of CFOR, and it has strong effects on performance of fine-grained object retrieval. Therefore, attribute learning is considered one of the important parts of the learning strategy.

Attribute Learning.

This method is used for object recognition systems at the fine-grained level. Unlike learning methods that are used for high-level concepts, attribute learning offers a solution for midlevel semantic concepts or visual concepts, which are known to be (more or less) correlated with each other. There are two main learning methods: single-task learning and multitask learning.

  • Single-task attribute learning: in this type, each attribute has its own learning model, so the number of models equals the number of attributes. Moreover, each attribute is treated separately, so the inner-group correlations are not exploited. Many works in the fashion field use single-task learning for fashion attributes. At that time, there were many challenges in multitask learning. A shared CNN was then defined to pave the way for multitask multilabel predictions, which made multitask learning possible.

  • Multitask attribute learning: to apply this technique to attributes, samples are collected by merging the given datasets into one with a one-hot binary vector representation. As in single-task learning, the input is the image. Unlike the output of single-task learning, which is a value that describes the existence (or not) of one attribute in an image, the output of multitask learning is a one-hot binary vector that describes the existence (or not) of each attribute in a group. Rudd has shown that joint optimization over all attributes outperforms training a single independent network with the same architecture for each attribute, in which the feature space is optimized along with the classifier on a per-attribute basis, in terms of accuracy, storage, and processing efficiency. This result shows that the multitask approach is much more effective in exploiting latent correlations than independent classifiers.

Although multitask learning can yield better performance than single-task learning, its critical weakness is that the model cannot be reused when any attribute changes. Retraining or an additional model is needed whenever a new attribute is added. This lack of reuse is why multitask learning methods are not flexible for attributes that change frequently. To address these challenges, we propose local multitask attribute learning, a grouping method based on object ontology that improves reuse.
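A small sketch may help show why grouping improves reuse. It assumes that each attribute group from the ontology becomes one local task with its own binary target vector; the group names, attribute names, and the `encode` helper are hypothetical.

```python
# Hypothetical attribute groups, as an ontology might define them.
ATTRIBUTE_GROUPS = {
    "texture": ["striped", "floral", "dotted"],
    "fabric": ["cotton", "denim"],
}


def encode(sample_attributes, groups):
    """Build one binary target vector per group (one local task each)."""
    return {
        group: [1 if name in sample_attributes else 0 for name in names]
        for group, names in groups.items()
    }


targets = encode({"striped", "denim"}, ATTRIBUTE_GROUPS)
print(targets)

# Adding an attribute changes only its own group's vector length,
# so only that group's model would need retraining.
extended = dict(ATTRIBUTE_GROUPS, fabric=ATTRIBUTE_GROUPS["fabric"] + ["leather"])
new_targets = encode({"striped", "denim"}, extended)
print(new_targets["texture"] == targets["texture"], len(new_targets["fabric"]))
```

With a single global vector, the same change would alter the output layer shared by every attribute, which is the reuse problem described above.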

Figure 3. Attribute learning model based on deep features with SVM classifiers.
Figure 4. Attribute learning model based on adaptive attribute domain with independent deep convolutional neural networks.
Figure 5. Attribute learning model based on the end-to-end deep neural network as a shared block with adaptive loss function.

Next

In the next post, we will cover the imbalanced data problem and our proposal in detail.

Controllable Group Choreography using Contrastive Diffusion (Part 6)

In the previous part, we covered the qualitative and quantitative experiments. In this part, let us investigate the ablation studies and user studies.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Ablation Analysis

Loss Terms. The contribution of the geometric loss $\mathcal{L}_{\rm geo}$ and contrastive loss $\mathcal{L}_{\rm nce}$ in GCD is thoroughly analyzed and presented in Table 1. The results demonstrate that both losses play a crucial role in enhancing the overall performance across all four evaluation protocols. In particular, it can be seen that the effect of $\mathcal{L}_{\rm geo}$ on realism metrics (FID, GMR, and PFC) is significant. This observation can be attributed to the fact that this loss improves the physical plausibility and naturalness of the dance motions, empirically mitigating common artifacts such as jittery motion or foot skating. By enforcing the geometric constraints, GCD can generate faithful motions that are on par with real dances. Moreover, the contrastive loss $\mathcal{L}_{\rm nce}$ contributes positively to the favorable results in the synchrony measures (GMR and GMC). This loss term encourages the model to synchronize the movements of multiple dancers within a group, thus improving the harmonious coordination and cohesion of the generated choreographies. In general, the results of $\mathcal{L}_{\rm geo}$ and $\mathcal{L}_{\rm nce}$ validate their importance across various evaluation metrics.

Table 1. Global module contribution and loss analysis. Experiments are conducted on GCD with $\gamma = 0$ (neutral mode).

Group Global Attention. Results presented in the first two rows of Table 1 demonstrate substantial improvements obtained by incorporating Group Global Attention into GCD. Without the Group Global Attention, the performance on the group dance metrics (GMR and GMC) is significantly degraded. We also observe that removing this block results in inconsistent movements across dancers: they seem to dance in freestyle without any group choreographic rules and collide with each other in many cases, although they may still follow the rhythm of the music. These results suggest the vital importance of ensuring coherency and regulating collisions for visually appealing group dance animations.

2. User Study

Qualitative user studies are important for evaluating generative models as the perception of users tends to be the most relevant metric for many downstream applications. Therefore, we conduct user studies to evaluate our approach in terms of group choreography generation. We organized two separate studies and enlisted roughly 50 individuals with diverse backgrounds to participate in our experiment. Each participant should have some relevant experience in music and dance (at least 1 month of studying or working in dance-related professions). The age of participants varied between 20 and 50, with approximately 55% female and 45% male.

In the initial study, we requested the participants to evaluate the dancing animations based on three criteria: the naturalness of the dancing motions (Realism), how well the movements match the music (Music-Motion Correspondence), and how well the dancers interact or synchronize with each other (Synchronization between Dancers). Participants were asked to rate scores from 0 to 10 for each criterion, ranging from 0 (very poor) through 5 (acceptable) to 10 (very good). The collected scores were then normalized to the range $[0,1]$.

This user study encompassed a total of $189 \times 3$ samples with songs that are not present in the train set, including those generated from GDanceR, real dance clips from the dataset, and generated results from our proposed method in neutral mode. Figure 1 shows the average scores of the three methods across the three criteria. Notably, the ratings of our method are significantly higher than GDanceR across all three criteria. We also perform Tukey honest significance tests to determine the significant differences among the three methods. For the first two criteria (Realism and Music-Motion Correspondence), we observe that the mean scores of all methods are significantly different with $p < 0.05$. For the synchronization criterion, the differences are significant except for the scores between our method and real dances ($p \approx 0.07$). This highlights that our method can even achieve comparable scores with real dances, especially in the synchronization evaluation. This can be attributed to the proposed contrastive diffusion strategy, which can effectively maintain a balance between the consistency of the movements and the group/audio context, as well as diversity in generated dances.

Figure 1. User study results in three criteria: Realism; Music-Motion Correspondence; and Synchronization between Dancers.

In the second study, we aim to assess the diversity and consistency of the generated dance outputs and determine whether they meet the expectations of the users. Specifically, participants were asked to assign scores ranging from 0 to 10 to evaluate the consistency and diversity of each dance clip, i.e., how synchronized or how distinctive the dancers' movements are within the group dance. A lower score indicated higher consistency, while a higher score indicated greater diversity. These scores were subsequently normalized to $[-1,1]$ to align with the range of the control parameter employed in our proposed method.

Figure 2 depicts a scatter plot illustrating the relationship between the scores provided by the participants and the $\gamma$ parameter that was used to generate the dance samples. The parameter values were randomly drawn from a uniform distribution with range $[-1,1]$ to create the animations along with randomly sampled musical pieces. The survey shows a strong correlation between the user scores and the control parameter, in which we calculated the correlation coefficient to be approximately 0.88. The results indicate that the diversity and consistency level of the generated group choreography samples is mostly in agreement with the user evaluation, as indicated by the scores obtained.

Figure 2. Correlation between the controlling consistency/diversity and the scores provided by the users.
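The score processing in the second study can be sketched as follows. This is an illustration under stated assumptions: the post does not say which correlation coefficient was used, so Pearson's is assumed here, and the control values, ratings, and sign convention of the toy data are invented rather than taken from the study.

```python
import math


def normalize(score, low=0.0, high=10.0):
    """Map a raw score in [low, high] linearly onto [-1, 1]."""
    return 2.0 * (score - low) / (high - low) - 1.0


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Made-up control values and raw 0-10 ratings, for illustration only.
gammas = [-1.0, -0.5, 0.0, 0.5, 1.0]
raw_scores = [1, 4, 3, 8, 9]

user_scores = [normalize(s) for s in raw_scores]
r = pearson(gammas, user_scores)
print(user_scores, round(r, 2))
```

Because the normalization is linear, it does not change the correlation itself; it only puts the user scores on the same scale as the control parameter so the scatter plot is easy to read.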

3. Discussion and Conclusion

While controlling consistency and diversity in group dance generation by using our proposed GCD has numerous advantages and potentials, there are certain limitations. Firstly, it requires tuning the parameters and a complex system that is not trivial to train, to ensure that the generated dance motions can produce the desired level of similarity among dancers while still presenting enough variation to avoid repetitive or monotonous movements. This may involve long inference processes and may require significant computational resources in both the training and testing phases.

Secondly, over-controlling consistency and diversity may introduce constraints on the creative freedom of generated dances. While enforcing consistency can lead to synchronized and harmonious group movements, it may limit the possibility of exploring unconventional or new experimental dance styles. On the other hand, promoting diversity results in unique and innovative dance sequences, but it may sacrifice coherence and coordination among dancers.

Although our model can synthesize semantically faithful group dance animation with effective coordination among dancers, it does not capture clear physical contact between dancers such as hand touching. This is because the data we used in training does not contain such detailed hand motion information. We think that exploring group dance with realistic physical hand interactions is a promising area for future work. Additionally, while our method offers a trade-off between diversity and consistency, achieving perfect alignment between high-diversity movements and music remains a challenging task. The diversity level among dancers and the alignment with music are also heavily influenced by the training data. We believe further efforts are required to reach this.

Lastly, the subjective nature of evaluating consistency and diversity poses a challenge. Existing metrics may not fully capture these aspects. We believe it is essential to consider diverse perspectives and involve domain experts to validate the effectiveness and quality of the generated dance motions.

To conclude, we have introduced GCD, a new method for audio-driven group dance generation that effectively controls the consistency and diversity of generated choreographies. By using contrastive diffusion along with the guidance technique, our approach enables the generation of a flexible number of dancers and long-term group dances without compromising fidelity. Through our experiments, we have demonstrated the capability of GCD to produce visually appealing and synchronized group dance motions. The results of our evaluation, including comparisons with existing methods, highlight the superior performance of our method across various metrics including realism and synchronization. By enabling control over the desired levels of consistency and diversity while preserving fidelity, our work has the potential for applications in entertainment, virtual performances, and artistic expression, advancing the effectiveness of deep learning in generative choreography.

Controllable Group Choreography using Contrastive Diffusion (Part 5)

In the previous part, we covered the experimental setups and quality analysis. In this part, let's investigate the consistency and diversity factors in dance generation.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Diversity and Consistency Analysis

Table 1 presents an in-depth analysis of our method's performance across seven evaluation metrics by adjusting the parameter $\gamma$ to control the consistency and diversity of the generated group choreographies. The findings reveal that our GCD with the high consistency setting ($\gamma = 1$) performs better than other settings in terms of the MMC, GMC, and TIF metrics, whereas the high diversity setting ($\gamma = -1$) achieves better results in the GenDiv metric. Meanwhile, the default model shows the best performance in both realism metrics (FID and GMR). It can also be seen that the physical plausibility score (PFC) is relatively robust to the setting, as there are no noticeable differences in this metric among the three settings. This implies that our model is able to create group animation with different consistency or diversity levels without compromising the plausibility of the movements too much. More interestingly, we found that there are indeed positive correlations between the two measures MMC and GMC and the trade-off parameter: these metrics improve as the consistency level increases. This is reasonable as we expect higher correspondence between the motion and the music (MMC) or higher correlation of the group motions (GMC) when the consistency level grows, which also agrees with the definition of these metrics.

Table. 1. Performance comparison. High Consistency: parameter $\gamma=1$; High Diversity: parameter $\gamma=-1$; Neutral: parameter $\gamma=0$.

Consistency setups in GCD lead to more similar movements between dancers. As a result, this similarity contributes to strong results in the MMC, GMC, and TIF metrics. In contrast, the diversity setups can synthesize more complex motions with greater variation between dancers as measured by the GenDiv metric, but this also makes it more challenging to reach good values of FID, MMC, and TIF, compared with other setups. In addition, Figure 1 shows an example of correlations between motion beats and music beats under high consistency and high diversity settings. The music beats are extracted using the beat tracking algorithm from the Librosa library. Notably, the velocity curves in the high consistency setting display relatively similar shapes, whereas in the high diversity scenario, the curves are clearly distinguished among dancers.

Figure. 1. Correlation between the motion and music beats. The solid curve represents the kinetic velocity of each dancer over time and the vertical dashed line depicts the music beats of the sequence. The motion beats can be detected as the local extrema from the kinetic velocity curve.
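To make the beat analysis in Figure 1 more concrete, here is a minimal sketch of how motion beats could be read off a kinetic velocity curve and compared with music beats. The function names, the `(T, J, 3)` joint layout, and the toy curve are assumptions for illustration only. The caption speaks of local extrema; this sketch picks local minima for concreteness. In practice the music beat frames would come from a beat tracker such as the one in Librosa rather than being typed in by hand.

```python
import numpy as np

def kinetic_velocity(joints, fps=30):
    """Mean joint speed per frame for one dancer; joints has shape (T, J, 3)."""
    vel = np.diff(joints, axis=0) * fps                 # (T-1, J, 3)
    return np.linalg.norm(vel, axis=-1).mean(axis=-1)   # (T-1,)

def motion_beats(velocity):
    """Frame indices of strict local minima of a 1-D velocity curve."""
    v = np.asarray(velocity, dtype=float)
    inner = (v[1:-1] < v[:-2]) & (v[1:-1] < v[2:])
    return np.where(inner)[0] + 1

def beat_offsets(music_beats, m_beats):
    """Distance in frames from each music beat to the nearest motion beat."""
    music_beats = np.asarray(music_beats)[:, None]
    return np.abs(music_beats - np.asarray(m_beats)[None, :]).min(axis=1)

# Hypothetical velocity curve whose minima fall every 15 frames.
t = np.arange(150)
curve = 1.0 - np.cos(2 * np.pi * t / 15)
beats = motion_beats(curve)
offsets = beat_offsets([14, 31, 60], beats)   # hand-picked music beat frames
```

Small offsets mean the music beats sit close to the motion beats, which is the pattern described for both settings.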

Despite the greater variations in high diversity setting, we can observe that the generated motions are matched with the music as the music beats are mostly located near the extrema of the motion curves in both settings. The experiment indicates that our model can faithfully capture different aspects of group choreography with different settings, including diversity and synchrony of the motions. This demonstrates the potential of our method towards various dance applications such as dance training or composing. Furthermore, our method can also produce distinctively different animation sequences under the same setting while adhering to the input music.
It is also important to note that all three setups of GCD significantly outperform other baseline models. This verifies the effectiveness of our proposed approach and shows that it can create high-fidelity group dance animations in any setting. For a more detailed visualization of the results, please refer to Figure 2.

Figure. 2. Consistency and diversity trade-off. High Consistency: parameter $\gamma=1$; High Diversity: parameter $\gamma=-1$; Neutral: parameter $\gamma=0$.

2. Number of Dancers Analysis

Table 2 provides insights into results obtained when generating arbitrarily different numbers of dancers using our proposed GCD in the neutral setting. In general, FID, GMR, and GMC metrics do not exhibit a clearly strong correlation with the number of generated dancers but display diverse and varied results. The MMC metric consistently shows its stability across all setups.

Table. 2. Performance of group dance generation methods when we increase the number of generated dancers, compared with GDanceR. In the GCD setup, Neutral mode with $\gamma = 0$ is used.

As the number of generated dancers increases, the generation diversity (GenDiv) decreases while the trajectory intersection frequency (TIF) increases. However, it is worth noting that the differences observed in these metrics are relatively minor compared to those produced by GDanceR. This implies that our method can effectively control consistency and diversity, significantly reducing the chances of collisions between dancers and maintaining the overall quality of generated group dance motions.

For a detailed visualization of results, please refer to Figure 3. These results underscore the robustness and flexibility of our method in generating group dance motions across varying numbers of dancers while ensuring consistency, diversity, and avoiding collisions between performers.

Figure. 3. Group dance generation results of GCD in terms of different numbers of dancers.

3. Long-term Analysis

To evaluate the efficacy of the guidance signal in GCD for creating long-term group dance sequences, we conducted a comparative analysis between GCD and the baseline model GDanceR. The experiment involved musical pieces of different durations: 15 seconds, 30 seconds, and 60 seconds. We show the results with the guidance parameter $\gamma = 0.5$ to enforce consistency with the music over a long duration. For more detailed information, please refer to Figure 5 and our supplementary video.

Figure. 5. Long-term results of the 60-second clip. For clearer visualization, please visit our demo video.

While both methods produce satisfactory results in the first few seconds of the animations (e.g., about 5-6 seconds), GDanceR starts to exhibit floating and unrealistic movements or freezes into a mean pose in the later period of the sequence. Figure 4 shows the motion change comparison between our GCD method and GDanceR. The motion change magnitudes are calculated as average differences of the kinetic features between consecutive frames. It is evident that the motion change magnitude of GDanceR gradually decreases and approaches zero in the latter half of the 60-second music piece, whereas our method preserves high magnitudes and variations over time. This is because GDanceR generates almost frozen dance choreographies during this period. In contrast, the group dance motions produced by GCD remain natural with diverse movements throughout the entire duration of all music samples.

Figure. 4. Motion changes comparison between our proposed GCD and GDanceR. The experiment is conducted on generated group dance results of 60-second music pieces.
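The motion change magnitude described above can be sketched in a few lines. This is only an illustration under assumptions: the kinetic features are taken as a ready-made `(T, D)` array, the helper names are hypothetical, and the freeze threshold `eps` is an arbitrary value chosen for the toy example, not one reported in the experiments.

```python
import numpy as np

def motion_change_magnitude(features):
    """Average absolute change of kinetic features between consecutive frames.

    features: array of shape (T, D), one kinetic feature vector per frame.
    Returns an array of shape (T-1,).
    """
    features = np.asarray(features, dtype=float)
    return np.abs(np.diff(features, axis=0)).mean(axis=1)

def frozen_fraction(magnitude, eps=1e-3):
    """Fraction of frame transitions whose change magnitude is below eps (assumed threshold)."""
    return float((np.asarray(magnitude) < eps).mean())

# Toy sequence: moves for 50 frames, then freezes for 50 frames.
moving = np.linspace(0.0, 1.0, 50)[:, None] * np.ones((1, 4))
frozen = np.ones((50, 4))
magnitude = motion_change_magnitude(np.concatenate([moving, frozen]))
```

A curve that collapses toward zero, as in the second half of this toy sequence, corresponds to the frozen poses discussed for the baseline.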

These findings confirm that our approach can effectively address the problem of motion generation in long-horizon group dance scenarios. It maintains the motion quality and dynamics of the dance motions, ensuring that the created animations remain visually appealing throughout extended periods. This highlights the advantage of the contrastive strategy to enhance the consistency of the movements of dancers with their group and the music, resulting in significant improvements for long-term dance sequence generation compared to the baseline GDanceR.

Next

In the next post, we will consider Ablation Studies and User Studies.

Controllable Group Choreography using Contrastive Diffusion (Part 4)

In the previous part, we covered the theory of the Music-Motion Transformer. In this part, let's investigate the experimental setups used to verify its effectiveness, as well as the quality results.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Experiments

1.1 Implementation Details

The hidden layer of all the MLPs consists of 512 units followed by GELU activation. The hidden dimension of all attention layers is set to $d=512$, and the attention adopts a multi-head scheme with 8 attention heads. We also use a feature-wise linear modulation (FiLM) after each attention layer to strengthen the influence of the conditioning context. At the end of each attention block, we append a 2-layer feed-forward network with a feed-forward size of 1024 to enhance the expressivity of the learned features. We extract the features from the raw audio signal by leveraging the representations from the frozen Jukebox, a pre-trained generative model for music, to enhance the model's generalization ability to several kinds of in-the-wild music. In total, the Group Diffusion Denoising Network is comprised of $L=5$ stacked Music-Motion Transformer and Group Global Attention blocks, along with 2 transformer encoder layers to encode the music features. We implement the architecture of the Contrastive Encoder similarly to the Denoising Network but without cross attention since it does not take the music as input. The output sequence of the Contrastive Encoder is then averaged out and fed into an output layer with one unit. We also make the Contrastive Encoder aware of the current step in the diffusion chain by appending the diffusion timestep embedding to the motion sequence so that it can provide correct guidance signals in the sampling process. Overall, our model has approximately 62M trainable parameters.

1.2 Training

To train the denoising diffusion network, we use the "simple" objective as:

$$\mathcal{L}_{\rm simple} = \mathbb{E}_{x_0\sim q(x_0|c),\, m\sim [1,M]} \left[\Vert x_0 - \mathcal{G}_\theta(x_m,m,c) \Vert^2_2\right]$$

To improve the physical plausibility and prevent artifacts of the generated motion, we also utilize auxiliary geometric losses.

$$\mathcal{L}_{\rm geo} = \lambda_{\rm pos}\mathcal{L}_{\rm pos} + \lambda_{\rm vel}\mathcal{L}_{\rm vel} + \lambda_{\rm foot}\mathcal{L}_{\rm foot}$$

In particular, geometric losses mainly consist of (i) a joint position loss $\mathcal{L}_{\rm pos}$ to better constrain the global joint hierarchy via forward kinematics; (ii) a velocity loss $\mathcal{L}_{\rm vel}$ to increase the smoothness and naturalness of the motion by penalizing the difference between the velocities of the ground-truth and predicted motions; and (iii) a foot contact loss $\mathcal{L}_{\rm foot}$ to mitigate foot skating artifacts and improve the realism of the generated motions by ensuring the feet stay stationary when ground contact occurs.
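The three geometric terms can be sketched as follows. This is a simplified illustration, not the training code: joint positions are assumed to be already computed by forward kinematics, the contact labels are assumed to be given as a binary array, and all function and argument names are hypothetical. The default weights are the values reported in the training setup.

```python
import numpy as np

def position_loss(pred_joints, gt_joints):
    """Mean squared error between joint positions, each of shape (T, J, 3)."""
    return float(np.mean((pred_joints - gt_joints) ** 2))

def velocity_loss(pred_joints, gt_joints):
    """Mean squared error between frame-to-frame velocities."""
    pred_vel = np.diff(pred_joints, axis=0)
    gt_vel = np.diff(gt_joints, axis=0)
    return float(np.mean((pred_vel - gt_vel) ** 2))

def foot_contact_loss(pred_joints, contact, foot_idx):
    """Penalize predicted foot velocity on frames flagged as ground contact.

    contact: binary array of shape (T-1, len(foot_idx)), assumed given.
    """
    foot_vel = np.diff(pred_joints[:, foot_idx], axis=0)   # (T-1, F, 3)
    return float(np.mean((foot_vel * contact[..., None]) ** 2))

def geometric_loss(pred, gt, contact, foot_idx,
                   w_pos=1.0, w_vel=1.0, w_foot=0.005):
    """Weighted sum of the position, velocity, and foot contact terms."""
    return (w_pos * position_loss(pred, gt)
            + w_vel * velocity_loss(pred, gt)
            + w_foot * foot_contact_loss(pred, contact, foot_idx))
```

Note that a constant offset between prediction and ground truth is penalized by the position term only, while a sliding foot during contact is penalized by the foot term even if the ground truth is unavailable for that frame.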

Our total training objective is the combination of the "simple" diffusion objective, the auxiliary geometric losses, and the contrastive loss:

$$\mathcal{L} = \mathcal{L}_{\rm simple} + \mathcal{L}_{\rm geo} + \lambda_{\rm nce}\mathcal{L}_{\rm nce}$$

We train our model on 4 NVIDIA V100 GPUs using the Adam optimizer with a learning rate of $1\mathrm{e}{-4}$ and a batch size of 64 per GPU, which took about 7 days for 500k iterations. The models are trained with $M=1000$ diffusion noising steps and a cosine noise schedule. During training, group dance motions are randomly sampled with sequence length $T=150$ at 30 Hz, which corresponds to 5-second pieces of music. For geometric losses, the loss weights are empirically set to $\lambda_{\rm pos}=1.0$, $\lambda_{\rm vel}=1.0$, and $\lambda_{\rm foot}=0.005$, respectively. For the contrastive loss $\mathcal{L}_{\rm nce}$, its weight is $\lambda_{\rm nce} = 0.001$, the probability of replacing dancers for negative sequences is 0.5, and the number of negative samples is empirically set to 10.
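For readers unfamiliar with the cosine noise schedule, the sketch below shows one common formulation together with the forward noising step $q(x_m|x_0)$. The offset `s = 0.008` and the helper names are assumptions taken from the widely used cosine schedule, since the exact variant is not specified here; only $M=1000$ comes from the setup above.

```python
import numpy as np

def cosine_alpha_bar(num_steps=1000, s=0.008):
    """Cumulative signal level alpha_bar_m for m = 0..num_steps (cosine schedule)."""
    m = np.arange(num_steps + 1)
    f = np.cos((m / num_steps + s) / (1 + s) * np.pi / 2) ** 2
    return f / f[0]

def q_sample(x0, m, alpha_bar, rng):
    """Draw x_m ~ N(sqrt(alpha_bar_m) * x0, (1 - alpha_bar_m) * I)."""
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[m]) * x0 + np.sqrt(1.0 - alpha_bar[m]) * noise

alpha_bar = cosine_alpha_bar(1000)
rng = np.random.default_rng(0)
x0 = rng.standard_normal((3, 150, 8))        # toy (dancers, frames, features)
x_mid = q_sample(x0, 500, alpha_bar, rng)     # partially noised group sequence
```

At step 0 the sample equals the clean motion, and by the last step almost no signal remains, which is the property the denoising network is trained to invert.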

1.3 Testing

At test time, we use the DDIM sampling technique with 50 steps to accelerate the sampling speed of the reverse diffusion process. Accordingly, our model can achieve real-time generation at 30 Hz on a single RTX 2080Ti GPU (excluding the music feature extraction step), thanks to the parallelization of the Transformer architecture.

To enable long-term generation, we divide the input music sequence into multiple overlapping chunks, with each chunk having a maximum window size of 5 seconds and overlapped by half with the adjacent chunk. The group dance motions are then generated for each chunk along with the corresponding audio. Subsequently, we merge the outputs by blending the overlapped region between two consecutive chunks using spherical linear interpolation, with the interpolation weight gradually decaying from the current chunk to the next chunk. However, for group choreography synthesis, our model generates dance motions for each dancer in random order. Therefore, we need to establish correspondences between dancers across the chunks (i.e., identifying which one of the $N$ dancers in the next chunk corresponds to a dancer in the current chunk). To accomplish this, we organize all dancers in the current chunk into one set and the dancers in the next chunk into another set, forming a bipartite graph between the two chunks. We can then utilize the Hungarian algorithm to find the optimal matching, where the Euclidean distance between the two pose sequences serves as the matching weights. Our blending technique is applied at each step of the diffusion sampling process, starting from pure noise, thus it allows the model to gradually denoise the chunks to make them compatible for blending.
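The dancer matching step is a minimum-cost assignment problem, which the sketch below illustrates. Two simplifications are made and should be read as assumptions: the optimal matching is found by brute force over permutations (only practical for a handful of dancers; `scipy.optimize.linear_sum_assignment` solves the same problem efficiently with a Hungarian-style algorithm), and the blend is a plain linear cross-fade on feature vectors rather than the spherical linear interpolation used for rotations. All names are hypothetical.

```python
import itertools
import numpy as np

def match_dancers(curr_overlap, next_overlap):
    """Match dancers across two chunks on their shared overlap region.

    curr_overlap, next_overlap: arrays of shape (N, T_overlap, D).
    Returns perm such that next_overlap[perm[i]] corresponds to curr_overlap[i].
    """
    n = curr_overlap.shape[0]
    # cost[i, j]: Euclidean distance between dancer i (current) and dancer j (next).
    diff = curr_overlap[:, None] - next_overlap[None, :]
    cost = np.linalg.norm(diff.reshape(n, n, -1), axis=-1)
    best = min(itertools.permutations(range(n)),
               key=lambda p: sum(cost[i, p[i]] for i in range(n)))
    return list(best)

def blend_overlap(curr_overlap, next_overlap):
    """Linear cross-fade whose weight decays from the current chunk to the next."""
    t = curr_overlap.shape[1]
    w = np.linspace(1.0, 0.0, t)[None, :, None]
    return w * curr_overlap + (1.0 - w) * next_overlap

# Toy check: the next chunk holds the same dancers in shuffled order.
rng = np.random.default_rng(0)
curr = rng.standard_normal((3, 10, 4))
true_perm = [2, 0, 1]
nxt = np.zeros_like(curr)
for i, j in enumerate(true_perm):
    nxt[j] = curr[i]
perm = match_dancers(curr, nxt)
blended = blend_overlap(curr, nxt[perm])
```

Once the permutation is known, the next chunk is reordered before blending so that each dancer is cross-faded with its own continuation.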

2. Experimental Settings

2.1 Dataset

We use the AIOZ-GDance dataset in our experiments. AIOZ-GDance is a large-scale group dance dataset including paired music and 3D group motions captured from in-the-wild videos using a semi-automatic method, covering 7 dance styles and 16 music genres.

2.2 Evaluation Protocol

We use the following metrics to evaluate the quality of single dancing motion: Frechet Inception Distance (FID), Motion-Music Consistency (MMC), Generation Diversity (GenDiv), Physical Foot Contact score (PFC). Concretely, FID score measures the realism of individual dance movements against the ground-truth dance. The MMC evaluates the matching similarity between the motion and the music beats, i.e., how well generated dances follow the beat of the music. The generation diversity (GenDiv) is evaluated as the average pairwise distance of the kinetic features of the motions. The PFC evaluates the physical plausibility of the foot movements by calculating the agreement between the acceleration of the character's center of mass and the foot's velocity.
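As an illustration of the GenDiv definition above, the sketch below computes the average pairwise distance over a set of feature vectors. The kinetic feature extraction itself is not shown and is assumed to have produced one vector per generated motion; the function name is hypothetical.

```python
import numpy as np

def generation_diversity(features):
    """Average pairwise Euclidean distance between kinetic feature vectors.

    features: array of shape (K, D), one feature vector per generated motion.
    """
    features = np.asarray(features, dtype=float)
    k = features.shape[0]
    dists = np.linalg.norm(features[:, None] - features[None, :], axis=-1)
    upper = np.triu_indices(k, k=1)      # each unordered pair counted once
    return float(dists[upper].mean())

# Toy example: two motions at distance 5 and one duplicate.
toy = np.array([[0.0, 0.0], [3.0, 4.0], [0.0, 0.0]])
gen_div = generation_diversity(toy)
```

Identical motions contribute zero distance, so a set of near-duplicates yields a low score.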

To evaluate the group dance quality, we follow three metrics: Group Motion Realism (GMR), Group Motion Correlation (GMC), and Trajectory Intersection Frequency (TIF).
In general, the GMR measures the realism between generated and ground-truth group motions by calculating the Frechet Inception Distance on the extracted group motion features. The GMC evaluates the synchrony between dancers within the generated group by calculating their cross-correlation. The TIF measures how often the generated dancers collide with each other in their dance movements.

2.3 Baselines

We compare our GCD method with several recent approaches on music-driven dance generation: FACT, Transflower, and EDGE, all of which are adapted for benchmarking in the context of group dance generation since the original methods were specifically designed for single-dance. We also evaluate against GDanceR, a recent model specifically designed for generating group choreography.

3. Quality Comparison

Table 1 shows a comparison among the baselines FACT, Transflower, EDGE, GDanceR, and our proposed GCD. The results clearly demonstrate that our default model setting with "neutral" mode outperforms the baselines significantly across all evaluations. We also observe that EDGE, a recent diffusion dance generation model, can yield very competitive performance on single-dance metrics (FID, MMC, GenDiv, and PFC). This suggests the advantages of diffusion approaches in motion generation tasks. However, it is still inferior to our model under several group dance metrics, showing the limitations of single dance methods in the context of group dance creation. Experimental results highlight the effectiveness of our approach in generating high-quality group dance motions.

Table. 1. Performance comparison. High Consistency: parameter $\gamma=1$; High Diversity: parameter $\gamma=-1$; Neutral: parameter $\gamma=0$.

To complement the quantitative analysis, we present qualitative examples from FACT, GDanceR, and our GCD method in Figure 1. Notably, FACT struggles to deal with the intersection problem, which is reasonable given that it was not originally designed for group dance generation. As a result, the generated motions from FACT lack coordination and synchronization in most cases. While GDanceR shows improvements in terms of motion quality compared to FACT, the generated motions appear floating, unnatural, and sometimes unsynchronized in many cases. These drawbacks indicate that GDanceR's effort on generating group choreography would still require more refinement to produce consistent and cohesive movements among the dancers.

Figure. 1. Comparison between different dance generation methods when generating dancing in groups.

In contrast, our method excels in both controlling consistency and promoting diversity among the generated group dance motions. The outputs from our method demonstrate well-coordinated and realistic movements, implying that it can resolve the challenges of maintaining group coherence while delivering visually appealing results more effectively.

Overall, the conducted quantitative analysis and visual comparisons reaffirm the superior performance of our proposed GCD to generate high-quality, synchronized, and visually pleasing group dance motions.

Next

In the next post, we will mention our Consistency and Diversity Analysis.

Controllable Group Choreography using Contrastive Diffusion (Part 3)

In the previous part, we covered the Music-Motion Transformer. In this part, let's investigate the Group Modulation and the Contrastive Diffusion Loss.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Group Modulation

To better apply the group information constraints to the learned hidden features of the dancers, we adopt a Group Modulation layer that learns to adaptively influence the output of the transformer attention block by applying an affine transformation to the intermediate features based on the group embedding $w$. More specifically, we utilize two separate linear layers to learn the affine transformation parameters $\{S(w); b(w)\} \in \mathbb{R}^d$ from the group embedding $w$. The predicted affine parameters are then used to modulate the activation sequence $h = \{h^1_1\dots h^1_T;\dots;h^N_1\dots h^N_T\}$ as follows:

$$\tilde{h} = S(w) * \frac{h-\mu(h)}{\sigma(h)}+ b(w)$$

where each channel of the whole activation sequence is first normalized separately by calculating the mean $\mu$ and standard deviation $\sigma$, and then scaled and biased using the predicted affine parameters $S(w)$ and $b(w)$. Intuitively, this operation shifts the activated hidden motion features of each individual motion towards a unified group representation to further encourage the association between them. Finally, the output features are projected back to the original motion dimensions via a linear layer to obtain the predicted outputs $\hat{x}_0$.
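The modulation formula can be sketched directly. The version below is a minimal illustration with assumed shapes: all dancers' tokens are flattened into one `(N*T, d)` array, the two linear layers are represented by explicit weight and bias arrays, and a small `eps` is added to the denominator for numerical safety. These choices and names are assumptions, not details from the text.

```python
import numpy as np

def group_modulation(h, w, W_s, b_s, W_b, b_b, eps=1e-5):
    """Normalize each channel of h, then scale and shift with parameters predicted from w.

    h: (N*T, d) activations of all dancers; w: (d_w,) group embedding.
    W_s, W_b: (d_w, d) weights and b_s, b_b: (d,) biases of the two linear layers.
    """
    scale = w @ W_s + b_s                     # S(w), shape (d,)
    shift = w @ W_b + b_b                     # b(w), shape (d,)
    mu = h.mean(axis=0, keepdims=True)        # per-channel mean over the whole sequence
    sigma = h.std(axis=0, keepdims=True)      # per-channel standard deviation
    return scale * (h - mu) / (sigma + eps) + shift

rng = np.random.default_rng(0)
h = rng.standard_normal((2 * 150, 16)) * 3.0 + 5.0   # toy activations, 2 dancers
w = rng.standard_normal(8)
# Layers that ignore w and output S(w) = 1, b(w) = 2, to make the effect easy to read.
out = group_modulation(h, w, np.zeros((8, 16)), np.ones(16),
                       np.zeros((8, 16)), np.full(16, 2.0))
```

Because every dancer's features pass through the same scale and shift, the group embedding acts on the whole group at once rather than on each dancer separately.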

2. Contrastive Diffusion

We learn the representations that encode the underlying shared information between the group embedding information $w$ and the group sequence $x$. Specifically, we model a density ratio that preserves the mutual information between $x$ and $w$ as:

$$f(x, w) \propto \frac{p(x|w)}{p(x)}$$

$f(\cdot)$ is a model (i.e., a neural network) to predict a positive score (how well $x$ is related to $w$) for a pair of $(x, w)$.

To enhance the association between the generated group dance (data) and the group embedding (context), we aim to maximize their mutual information with a Contrastive Encoder $f(\hat{x},w)$ via the contrastive learning objective in the equation below. The encoder takes both the generated group dance sequence $\hat{x}$ and a group embedding $w$ as inputs, and it outputs a score indicating the correspondence between these two.

$$\mathcal{L}_{\rm nce} = - \mathbb{E} \left[ \log\frac{f(\hat{x},w)}{f(\hat{x},w) + \sum_{x^j \in X'}f(\hat{x}^j, w)}\right]$$

where $X'$ is a set of randomly constructed negative sequences. In general, this loss is similar to the cross-entropy loss for classifying the positive sample, and optimizing it leads to the maximization of the mutual information between the learned context representation and the data. Using the contrastive objective, we expect the Contrastive Encoder to learn to distinguish between the two quantities: consistency (the positive sequence) and diversity (the negative sequence). This is the key factor that enables the ability to control diversity and consistency in our framework.
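To show how the contrastive objective behaves for a single positive sample, here is a small sketch that evaluates the loss from precomputed encoder scores. The scores are passed in as plain numbers, so the Contrastive Encoder itself is not modeled; the log-domain variant assumes the scores are parameterized as $f = \exp(\text{logit})$, which is a common choice but an assumption here. Function names are hypothetical.

```python
import numpy as np

def nce_loss(pos_score, neg_scores):
    """-log(f_pos / (f_pos + sum_j f_neg_j)) for strictly positive scores."""
    pos_score = float(pos_score)
    neg_scores = np.asarray(neg_scores, dtype=float)
    return float(-np.log(pos_score / (pos_score + neg_scores.sum())))

def nce_loss_from_logits(pos_logit, neg_logits):
    """Same loss with f = exp(logit), computed with a stable log-sum-exp."""
    logits = np.concatenate([[pos_logit], np.asarray(neg_logits, dtype=float)])
    m = logits.max()
    log_sum = m + np.log(np.exp(logits - m).sum())
    return float(log_sum - pos_logit)

# One positive and ten negatives, all scored equally: the loss is log(11).
loss_equal = nce_loss(1.0, [1.0] * 10)
# A well separated positive drives the loss toward zero.
loss_good = nce_loss(100.0, [0.01] * 10)
```

The second form makes the connection to cross-entropy explicit: it is the softmax classification loss with the positive sample as the correct class.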

Here, we describe our strategy to construct contrastive samples to achieve our target. Recall that we use the reverse distribution $p_\theta(x_{t-1} | x_t)$ of Gaussian diffusion, with the mean as the prediction of the model while the variance is fixed by a scheduler. To obtain the contrastive samples, given the true pair $(x_0,w)$, we first leverage the forward diffusion process $q(x_m|x_0)$ to obtain the noised sample $x_m$. Then, our positive sample is $p_\theta(x_{m-1} | x_m, w)$. Subsequently, we construct the negative sample from the positive pair by randomly replacing dancers from other group dance sequences ($x^j_0 \neq x_0$) with some probabilities, feeding it through the forward process to obtain $x^j_m$; our negative sample is then $p_\theta(x^j_{m-1} | x^j_m, w)$. By constructing contrastive samples this way, the positive pair $(x_0,w)$ represents a group sequence with high consistency, whereas the negative one represents a high diversity sample. This is because mixing a sample with dancers from different groups is likely to result in substantially distinctive movements between each dancer, making it a group dance sample with a high degree of diversity. Note that negative sequences should also match the music because they are motions generated by the network whose inputs are manipulated to increase diversity. Particularly, negative samples are acquired from outputs of the denoising network whose inputs are both the current music and the noised mixed group with some replaced dancers. As the network is trained to reconstruct only positive samples, its outputs will likely follow the music. Therefore, negative samples are not just random bad samples but are valid group dances generated from the network that is trained to generate group dance conditioned on the music. This is because our main diffusion training objective is calculated only for ground-truth dances (positive samples) that are consistent with the music. Our proposed strategy also allows us to learn a more powerful group representation as it directly affects the reverse process, which is beneficial to maintaining consistency in long-term synthesis.
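The first step of the negative sample construction, swapping dancers in from other sequences, might look like the sketch below. It covers only the mixing step: the forward noising and the pass through the denoising network that follow are not shown. Forcing at least one swap is an assumption added so the negative always differs from the positive, and the function and argument names are hypothetical.

```python
import numpy as np

def make_negative_group(group, other_groups, replace_prob=0.5, rng=None):
    """Randomly replace dancers of `group` with dancers from other sequences.

    group: (N, T, D) positive group sequence.
    other_groups: (G, N, T, D) pool of other group sequences.
    Returns (negative_group, replaced_mask).
    """
    rng = np.random.default_rng() if rng is None else rng
    n = group.shape[0]
    replaced = rng.random(n) < replace_prob
    if not replaced.any():                      # assumed: always swap at least one dancer
        replaced[rng.integers(n)] = True
    negative = group.copy()
    for i in np.where(replaced)[0]:
        g = rng.integers(other_groups.shape[0])  # which other sequence
        d = rng.integers(other_groups.shape[1])  # which dancer in that sequence
        negative[i] = other_groups[g, d]
    return negative, replaced

# Toy data: the positive group is all zeros, the pool is all ones.
group = np.zeros((4, 5, 3))
pool = np.ones((2, 4, 5, 3))
negative, mask = make_negative_group(group, pool, rng=np.random.default_rng(0))
```

With a replacement probability of 0.5, roughly half of the dancers in each negative come from unrelated choreographies, which is what makes the mixed group a high-diversity sample.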

3. Diversity vs. Consistency

Using the Contrastive Encoder $f(x_m,w)$, we extend the classifier guidance to control the generation process. Accordingly, we incorporate $f(x_m,w)$ in the contrastive framework to replace the guiding classifier in the original formula, since it provides a score of how consistent the sample is with the group information. In particular, we shift the mean of the reverse diffusion process with the log gradient of the Contrastive Encoder with respect to the generated data as follows:

$$\hat{\mu}_\theta(x_m,m) = \mu_\theta(x_m,m) + \gamma \cdot \Sigma_{\theta}(x_m,m) \nabla_{x_m}\log f(x_m,w)$$

where $\gamma$ is the control parameter that uses the encoder to enforce consistency and connection with the group embedding. Since the Contrastive Encoder is trained to classify between high-consistency and high-diversity samples, its gradients yield meaningful guidance signals to control the trade-off process. Intuitively, a positive value of $\gamma$ encourages more consistency between dancers while a negative value (which corresponds to shifting the distribution with a negative gradient step) boosts the diversity between each individual dancer.
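The effect of the sign of $\gamma$ can be seen with a toy stand-in for the encoder. In the sketch below, $\log f(x, w) = -\tfrac{1}{2}\Vert x - w\Vert^2$ is a hypothetical score chosen only because its gradient is available in closed form; the real Contrastive Encoder is a neural network whose gradient would come from automatic differentiation. The variance is treated as a scalar for simplicity.

```python
import numpy as np

def guided_mean(mu, sigma2, grad_log_f, gamma):
    """Shift the reverse-process mean: mu + gamma * Sigma * grad log f."""
    return mu + gamma * sigma2 * grad_log_f

def toy_grad_log_f(x, w):
    """Gradient of the toy score log f(x, w) = -0.5 * ||x - w||^2 with respect to x."""
    return w - x

x_m = np.zeros(3)            # current noisy sample (toy)
w = np.ones(3)               # toy "group embedding" in the same space
mu = x_m.copy()              # pretend the model predicts the mean at x_m
grad = toy_grad_log_f(x_m, w)

toward = guided_mean(mu, 0.1, grad, gamma=1.0)    # pulled toward higher score
away = guided_mean(mu, 0.1, grad, gamma=-1.0)     # pushed toward lower score
neutral = guided_mean(mu, 0.1, grad, gamma=0.0)   # unguided mean
```

A positive $\gamma$ moves the mean in the direction that raises the encoder score, a negative one moves it the opposite way, and $\gamma = 0$ recovers the unguided sampler.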

Next

In the next post, we will mention Experimental Setups and Analysis.

Controllable Group Choreography using Contrastive Diffusion (Part 2)

In the previous part, we discussed the motivation and recent works on generating dances and group dances, and analyzed the pros and cons of each method. We now dive into the detailed algorithm of our proposed GCD. Our methodology can be split into three parts: the Music-Motion Transformer, Group Global Attention, and the Contrastive Diffusion Loss for controllable motions. In this part, let's investigate the Music-Motion Transformer first.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Group Diffusion Denoising Network

Our model architecture is illustrated in Figure 1. We utilize a transformer-based architecture to generate the whole sequence in one go. Compared with recent auto-regressive approaches, our method does not suffer from the error accumulation problem (i.e., the prediction error accumulates over time since the current-frame outputs are used as inputs to the next frame in the auto-regressive fashion) and thus can generate arbitrarily long dance motion sequences without freezing effects. The input of our network at each diffusion step $m$ is the noisy group sequence $x_m =\{x^1_{m,1},\dots, x^1_{m,T}; \dots;x^N_{m,1},\dots,x^N_{m,T}\}$; however, we skip the $m$ index for ease of notation from now on.

2. Music-Motion Transformer:

Given an input extracted audio sequence $a = \{a_1, a_2, \dots,a_T\}$, we employ a transformer encoder architecture to encode the music into the sequence of hidden audio representations $\{c_1, c_2, \dots,c_T\}$, which will be used as the conditioning context to the diffusion denoising network. Specifically, we follow the encoder layer which consists of multi-head self-attention layers and feed-forward layers to effectively encode the multi-scale rhythmic patterns and long-term dependencies between music frames. The diffusion time-step $m$ is also projected to the transformer dimension through a separate Multi-layer Perceptron (MLP) with 3 hidden layers to get the embedding $\tau_{emb}$, then concatenated with the music feature sequence to obtain the final conditioning context $c = \{c_1, c_2, \dots,c_T, \tau_{emb}\}$.

Figure 1. Detailed illustration of our method for group choreography generation.

Although group choreography incorporates the problem of learning the interaction between dancers, we still need to learn the correlation between the dance movements and the accompanying music audio for each dancer. Therefore, we design the Music-Motion Transformer to essentially focus on learning the direct connection between the motion and the music of each individual dancer (and not considering the interconnection among dancers yet). Each frame of the noised input motion $x^i_t$ is projected into the transformer dimension by a linear layer followed by an additive positional encoding. Given the whole group sequence including all dancers $\{x^1_1,\dots, x^1_T; \dots;x^n_1,\dots,x^n_T\}$, we separately encode the motion features of each individual dancer by utilizing the multi-head self-attention with a masking strategy. We implement the masked self-attention (MSA) mechanism as follows:

$$\text{MSA}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + m_{local} \right) V, \quad Q = x W^{Q}, \quad K = x W^{K}, \quad V = x W^{V}$$

where $W^{Q}, W^{K} \in \mathbb{R}^{d \times d_k}$ and $W^{V} \in \mathbb{R}^{d \times d_v}$ are learnable projection matrices to transform the input to query, key, and value, respectively. $m_{local}$ is the local attention mask illustrated in Figure 2(a). This mask ensures each individual can only attend to their own motion. Subsequently, to incorporate the music conditioning context $c$ into each individual's motion features, we adopt a transformer decoder architecture with a cross-attention mechanism (CA), where the motion is the query and the music is the key/value.

$$\text{CA}(\tilde{Q}, \tilde{K}, \tilde{V}) = \text{softmax}\left(\frac{\tilde{Q}\tilde{K}^\top}{\sqrt{d_k}} \right) \tilde{V}, \quad \tilde{Q} = \tilde{x} \tilde{W}^{Q}, \quad \tilde{K} = c \tilde{W}^{K}, \quad \tilde{V} = c \tilde{W}^{V}$$

where $\tilde{x}$ is the output activation of the MSA block, and $\tilde{W}^{Q}$, $\tilde{W}^{K}$, $\tilde{W}^{V}$ are the learnable projection matrices that have similar behavior to the MSA mechanism.
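The masked self-attention and cross-attention steps can be illustrated with a single-head sketch. For brevity the learned projections $W^Q$, $W^K$, $W^V$ are omitted (treated as identity), which is a simplification and not how the model is built; the token layout, with each dancer's frames stored contiguously, and the function names are also assumptions.

```python
import numpy as np

def local_mask(n_dancers, n_frames):
    """Additive mask: 0 where two tokens belong to the same dancer, -inf elsewhere."""
    owner = np.repeat(np.arange(n_dancers), n_frames)   # dancer id of each token
    same = owner[:, None] == owner[None, :]
    return np.where(same, 0.0, -np.inf)

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, mask=None):
    """Scaled dot-product attention with an optional additive mask."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores + mask
    return softmax(scores) @ v

rng = np.random.default_rng(0)
x = rng.standard_normal((2 * 3, 4))       # 2 dancers, 3 frames each, d = 4
c = rng.standard_normal((5, 4))           # toy music context tokens

mask = local_mask(2, 3)
x_tilde = attention(x, x, x, mask)         # MSA: each dancer attends to itself only
out = attention(x_tilde, c, c)             # CA: motion queries, music keys/values
```

Because of the local mask, changing one dancer's input leaves the other dancer's self-attention output untouched, while the cross-attention step lets every motion token read from the full music context.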

3. Group Global Attention

To ensure the coherency and non-collision in the movements of all dancers within the group, such that their dances should correlate with each other under the music condition instead of dancing asynchronously , we first perform global attention via a masked attention mechanism with a full masking strategy mglobalm_{global}. The attention mask is illustrated in Figure2.b. It allows a dancer to fully attend to all other dancers under the global receptive field. Then, we propose the Group Modulation to enforce the group constraints within the group embedding information.

Figure 2. The local attention mask $m_{local}$ (a) and global attention mask $m_{global}$ (b). Blue cells indicate where frames can attend to each other (mask value zero), while gray cells are blocked (mask value minus infinity). $x^i_{[1:T]}$ indicates the motion sequence of the $i$-th dancer.
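To make the difference between the two masks concrete, the sketch below builds both for a toy group and shows the attention weights they produce when all raw scores are equal. The token ordering (all frames of one dancer, then the next) and the helper names are assumptions for illustration.

```python
import numpy as np

def build_masks(n_dancers, n_frames):
    # Token order assumed: all frames of dancer 1, then dancer 2, and so on.
    dancer_id = np.repeat(np.arange(n_dancers), n_frames)
    same_dancer = dancer_id[:, None] == dancer_id[None, :]
    m_local = np.where(same_dancer, 0.0, -np.inf)   # block-diagonal mask
    m_global = np.zeros_like(m_local)               # nothing is blocked
    return m_local, m_global

def attention_weights(mask):
    # Softmax of all-zero scores plus the mask, computed row by row.
    w = np.exp(mask)
    return w / w.sum(axis=-1, keepdims=True)

m_local, m_global = build_masks(n_dancers=2, n_frames=3)
w_local = attention_weights(m_local)     # weight is spread over own frames only
w_global = attention_weights(m_global)   # weight is spread over every dancer
```

With two dancers and three frames each, a row of `w_local` puts weight 1/3 on the dancer's own frames and 0 elsewhere, while every entry of `w_global` is 1/6.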

Since a synthesized image can be manipulated via a latent style vector, we aim to learn group embedding information from the input music in order to control the group dance generation process. We first apply temporal average pooling to the encoded music feature sequence to obtain a compact representation of the input music $\bar{c} = \frac{1}{T}\sum_{t=1}^T c_t$. To increase the variation and diversity of the group information (i.e., to avoid limiting the group embedding to only one style of the input music), we inject random noise drawn from a standard Gaussian distribution $z \sim \mathcal{N}(0,I)$ into $\bar{c}$. We use an 8-layer MLP to learn a mapping from the audio representation to the group embedding. We also add a learnable embedding token $e_{n}$ from a variable-size lookup table $E \in \mathbb{R}^{N\times D}$, with up to $N$ maximum dancers, to represent the variation of dancers in the sequence, since each sequence may contain a different number of dancers. In summary, the process can be written as follows:

$$w = \text{MLP}\left(z + \frac{1}{T}\sum_{t=1}^T c_t\right) + e_{n}, \quad z \sim \mathcal{N}(0,I)$$
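As a rough illustration of this equation, the following NumPy sketch computes a group embedding from toy music features. Several details are assumptions not stated in the text: the ReLU activation, the equal feature and embedding size `D`, the random (untrained) weights, and indexing the lookup table by the number of dancers in the sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, N_max = 16, 32, 5        # frames, feature size, max dancers (toy values)

def mlp(h, layers):
    # Plain MLP with ReLU between layers; the activation is an assumption.
    for i, (W, b) in enumerate(layers):
        h = h @ W + b
        if i < len(layers) - 1:
            h = np.maximum(h, 0.0)
    return h

# 8 layers, as in the text; weights are random stand-ins for learned ones.
layers = [(rng.normal(scale=np.sqrt(2.0 / D), size=(D, D)), np.zeros(D))
          for _ in range(8)]
E = rng.normal(size=(N_max, D))    # lookup table of learnable tokens e_n

def group_embedding(c, n_dancers):
    c_bar = c.mean(axis=0)             # temporal average pooling of the music
    z = rng.standard_normal(D)         # z ~ N(0, I)
    # Assumed indexing: row n-1 of E for a sequence with n dancers.
    return mlp(z + c_bar, layers) + E[n_dancers - 1]

c = rng.normal(size=(T, D))            # encoded music features
w_a = group_embedding(c, n_dancers=3)
w_b = group_embedding(c, n_dancers=3)  # same music, different noise draw
```

Calling the function twice on the same music yields two different embeddings, which reflects the role of the injected noise in diversifying the group information.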

Next

In the next post, we will mention Group Global Attention.

Controllable Group Choreography using Contrastive Diffusion (Part 1)

Music-driven group choreography poses a considerable challenge but holds significant potential for a wide range of industrial applications. The ability to generate synchronized and visually appealing group dance motions that are aligned with music opens up opportunities in many fields such as entertainment, advertising, and virtual performances. However, most recent works are not able to generate high-fidelity long-term motions, or fail to offer a controllable experience. In this work, we aim to address the demand for high-quality and customizable group dance generation by effectively governing the consistency and diversity of group choreographies. In particular, we utilize a diffusion-based generative approach to enable the synthesis of long-term group dances with a flexible number of dancers, while ensuring coherence with the input music. Ultimately, we introduce a Group Contrastive Diffusion (GCD) strategy to enhance the connection between dancers and their group, presenting the ability to control the consistency or diversity level of the synthesized group animation via the classifier-guidance sampling technique. Through intensive experiments and evaluation, we demonstrate the effectiveness of our approach in producing visually captivating and consistent group dance motions. The experimental results show the capability of our method to achieve the desired levels of consistency and diversity, while maintaining the overall quality of the generated group choreography.

Our code can be found at: https://github.com/aioz-ai/GCD

1. Introduction

With the widespread presence of digital social media platforms, the act of creating and editing dance videos has gained immense popularity among social communities. This surge in interest has resulted in the daily production and watching of millions of dancing videos across online platforms. Recently, researchers from computer vision, computer graphics, and machine learning communities have devoted considerable attention to developing techniques that can generate natural dance movements from music. These advancements have far-reaching implications and find applications in various domains, such as animation, the creation of virtual idols, the development of virtual meta-verse, and dance education. These techniques empower artists, animators, and educators alike, providing them with powerful tools to enhance their creative endeavors and enrich the dance experience for both performers and audiences.

While significant progress has been made in generating dancing motions for a single dancer, the task of producing cohesive and expressive choreography for a group of dancers has received limited attention. The generation of synchronized group dance motions that are both realistic and aligned with music remains a challenging problem in the field of computer animation and motion synthesis. This is primarily due to the complex relationship between music and human motion, the diverse range of motions required for group performances, and the lack of a suitable dataset. At present, AIOZ-GDance stands as the most recent extensive dataset available to facilitate the task of generating group choreography. Besides, while current algorithms can generate individual movements and choreographic sequences, ensuring that these elements align seamlessly with the overall group performance is also paramount.

Different from solo dance, group dance involves coordination and interaction between dancers, making it crucial and challenging to establish correlations between motion series within a group. Besides, group dance can involve complex and diverse choreographies among participating dancers while still maintaining a semantic relationship between the motion and input music. Exploring the consistency and diversity between the movements of dancers of the synthesized group choreography is of vital importance to create a natural and expressive performance. The ability to control the consistency and diversity in group dance generation holds great potential across various applications. One such application is in the realm of entertainment and performance. Choreographers and creative teams can leverage this control ability to design captivating group dance routines that seamlessly blend synchronized movements with moments of individual expression. Second, in the context of animation and virtual metaverses, the control over consistency and diversity allows for the creation of visually stunning and immersive virtual dance performances. By balancing the synchronization of dancers' movements, while also introducing variations and unique flourishes, the generated group dances can captivate audiences and evoke a sense of realism and authenticity. Last but not least, in dance education and training, the ability to regulate consistency and diversity in group dance generation can be invaluable. It enables instructors to provide students with a diverse range of generated dance routines and samples that challenge their abilities, promote collaboration, and foster creativity. By dynamically adjusting the level of consistency and diversity, educators can cater to the unique need and skill level of each individual dancer, creating more inclusive instructions and enriching the learning environment. 
Although plenty of applications can be listed, the consistency and diversity of group choreography have not been carefully explored, partly due to limitations in the available data.

In this paper, our goal is to develop a controllable technique for group dance generation. We present a Group Contrastive Diffusion (GCD) strategy that learns an encoder to capture the key targets between group dance movements. Diffusion modeling provides a flexible framework for manipulating the dance distribution, which allows us to modulate the degree of diversity and consistency in the generated dances. By using a denoising diffusion probabilistic model as the key technique, we can effectively control the trade-off between diversity and consistency during group dance generation, thanks to the guided sampling process. With this approach, we can guide the generation process toward a desired balance between diversity and consistency levels. Moreover, incorporating the encoder, which learns the association between the dancers and their group, helps keep the generated dance moves consistent with a specific dance style, music genre, or any long-term chorus. We empirically show that this approach has the potential to enhance the quality and naturalness of generated group dance performances, making it more appealing for various applications.

Figure 1. We present a contrastive diffusion method that controls the consistency (top row) and diversity (second row) in group choreography.

2. Overview

Music-driven Choreography. Creating natural and authentic human choreography from music is a complex task. One commonly employed technique involves using a motion graph derived from a vast motion database to generate new motions. This involves combining various motion segments and optimizing transition costs along the graph path. Alternatively, there are other methods that incorporate music-motion similarity matching constraints to ensure consistency between the motion and the accompanying music. Previous studies have extensively explored these methodologies. However, most of these approaches relied on heuristic algorithms to stitch together pre-existing dance segments sourced from a limited music-dance database. While these methods are successful in generating extended and realistic dance sequences, they face limitations when trying to create entirely novel dance fragments.

In recent years, considerable progress has been made in the field of music-to-dance motion generation using Convolutional Network (CNN), Recurrent Network (RNN), Graph Neural Network (GNN), Generative Adversarial Network (GAN), or Transformer. Typically, these methods rely on multiple inputs such as the current music and a brief history of past dance movements to predict the future sequence of human poses. Recently, Gong et al. proposed an interesting task of generating dance by simultaneously utilizing both music and text instruction. A music-text feature fusion module was designed to fuse the inputs into a motion decoder to generate dance conditioned on both music and text. However, although these methods have the potential to produce natural and realistic dancing motion, they are often unable to create synchronized and harmonious movements between multiple dancers. Ensuring coordination and synchronization between dancers is a complicated problem, as it involves not only individual pose predictions but also the seamless integration of these poses within the context of a group. Achieving synchronized and harmonious group movements requires considering spatial and temporal relationships among dancers, their interactions, and the overall choreographic structure. Further work has sought to address these challenges, including deep learning approaches such as Variational Autoencoder (VAE), GAN, and Normalising Flow.

Unfortunately, most of these networks are limited in their ability to model long-term dance sequences, which may freeze or drift towards the end of the music. To facilitate long-term generation, many researchers apply a motion repeat constraint to predict future frames by attending to the historical motions. Nevertheless, this limits the flexibility of the model by forcing it to always look into the past.

Group Choreography. Group choreography and its related problem, multi-person motion prediction, have been an active research area with numerous studies addressing the challenges of predicting the behaviors of multiple individuals. Alahi et al. utilize a Markov chain model to jointly analyze the trajectories of several pedestrians and predict their destinations in a given scene. Adel et al. integrate social interactions and the visual context of the environment to forecast the future motion of multiple individuals. Multi-Range Transformers have been proposed to predict the movements of groups with more than ten people engaging in social interactions. These methods leverage various techniques to capture social interaction, spatial dependencies, and temporal dynamics, generally aiming to predict accurate and socially plausible future motions for multiple individuals in different scenarios. However, despite the notable advancements achieved, there remains a demand for further investigation of the correlation between the consistency and diversity of motions within a group context. A deeper understanding of how to attain the optimal balance between consistency and diversity holds the potential to unlock new possibilities for creating group choreographies that benefit users in many circumstances.

Diffusion for Music-driven Choreography. Recently, diffusion-based approaches have shown remarkable results on several generative tasks ranging from image generation, audio synthesis, pose estimation, natural language generation, and motion synthesis, to point cloud generation, 3D object synthesis, and scene creation. Diffusion models have shown that they can achieve high mode coverage, unlike GANs, while still maintaining high sample quality. Most existing diffusion-based approaches for human motion/dance synthesis only focus on generating motion sequences for a single character, conditioned on information such as text, audio, or both audio and text. Different from these prior works, we aim to create group dancing motions from music, which includes coordinating multiple characters, avoiding collisions, and maintaining coherence between them. In addition to the vanilla diffusion loss term used for training in previous works, our method employs a contrastive learning strategy that directly influences the training of the diffusion reverse process, enhancing the association within the dance group.

A prominent issue of the diffusion approach for motion synthesis is that although it is highly effective in generating diverse samples, injecting a large amount of noise during the sampling process can lead to inconsistent results. This issue is particularly problematic for the group dance paradigm. Therefore, our desideratum is to design a group dance generation model with the ability to address both the diversity and consistency problems.

3. Preliminaries

3.1 Background

Given an input music sequence $\{a_1, a_2, ..., a_T\}$, where $t \in \{1,..., T\}$ indicates the index of the music frames, our goal is to generate the group motion sequences of $N$ dancers: $\{x^1_1,..., x^1_T; ...;x^N_1,...,x^N_T\}$, where $x^i_t$ is the pose of the $i$-th dancer at frame $t$. We represent dances as sequences of poses over the 24 joints of the SMPL model, using the 6D continuous rotation for every joint, along with a single 3D root translation. This rotation representation ensures the uniqueness and continuity of the rotation vector, which is more beneficial to the training of deep neural networks. We tackle the group dance generation task by using a diffusion-based framework to synthesize the motions from a random noise distribution, given the music conditioning. Thanks to the sampling process of the diffusion model, we can effectively control the consistency and diversity in the generated sequences.

3.2 Forward Process of Diffusion Model.

Given an original sample from the real data distribution $x_0 \sim q(x_0)$, following DDPM (Ho et al., 2020), the forward diffusion process is defined as a Markov process that gradually adds Gaussian noise to the data under a pre-defined noise schedule up to $M$ steps.

$$q(x_m | x_{m-1}) = \mathcal{N}(x_m; \sqrt{1-\beta_m}\, x_{m-1}, \beta_m I), \quad \forall m \in \{1,2, ... ,M\}$$

If the noise variance schedule $\beta_m$ is small and the number of diffusion steps $M$ is large enough, the distribution $q(x_M)$ at the end of the process is well-approximated by a standard normal distribution $\mathcal{N}(0,I)$, which is easy to sample from. Thanks to this property of the forward diffusion, we can directly obtain the noised sample at any arbitrary step $m$ without traversing the whole chain:

$$q(x_m | x_{0}) = \mathcal{N}(x_m; \sqrt{\bar{\alpha}_m}\, x_{0}, (1-\bar{\alpha}_m)I), \qquad x_m = \sqrt{\bar{\alpha}_m}\, x_{0} + \sqrt{1-\bar{\alpha}_m}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

where $\alpha_m = 1-\beta_m$ and $\bar{\alpha}_m = \prod_{s=1}^m \alpha_s$.
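The closed-form expression means a training sample at any step can be drawn in a single operation. The sketch below illustrates this with NumPy; the linear schedule values, the number of steps, and the function name `q_sample` are assumptions for the example, not settings reported in this post.

```python
import numpy as np

M = 1000
betas = np.linspace(1e-4, 0.02, M)     # linear noise schedule (assumed values)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)         # alpha_bar[m-1] = prod_{s=1..m} alpha_s

def q_sample(x0, m, rng):
    # Draw x_m ~ q(x_m | x_0) directly, for a step m in {1, ..., M}.
    eps = rng.standard_normal(x0.shape)
    a = alpha_bar[m - 1]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps, eps

rng = np.random.default_rng(0)
# Toy motion clip: 4 frames, 24 joints x 6D rotation + 3D root translation.
x0 = rng.normal(size=(4, 24 * 6 + 3))
x_early, _ = q_sample(x0, 1, rng)      # almost identical to x0
x_mid, eps_mid = q_sample(x0, 500, rng)
x_last, _ = q_sample(x0, M, rng)       # close to pure Gaussian noise
```

With this schedule the cumulative product at the final step is tiny, so `x_last` carries almost no trace of `x0`, in line with the approximation $q(x_M) \approx \mathcal{N}(0,I)$.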

3.3 Reverse Process.

By additionally conditioning on $x_0$, the posterior of the reverse process is tractable and becomes a Gaussian distribution:

$$q(x_{m-1} | x_m, x_0) = \mathcal{N}(x_{m-1}; \tilde{\mu}_m, \tilde{\beta}_m I),$$

where $\tilde{\mu}_m$ and $\tilde{\beta}_m$ are the posterior mean and variance, determined by $x_m$, $x_0$, and the noise schedule. To obtain a sample from the original data distribution, we start by sampling from the noise distribution $q(x_M)$ and then gradually remove the noise until we reach $x_0$, following the reverse process. Therefore, our goal is to train a neural network to approximate the posterior $q(x_{m-1} | x_m)$ of the reverse process as:

$$p_{\theta}(x_{m-1} | x_m) = \mathcal{N}(x_{m-1}; \mu_{\theta}(x_m, m), \Sigma_{\theta}(x_m,m))$$

We model only the mean $\mu_{\theta}(x_m, m)$ of the reverse distribution while keeping the variance $\Sigma_{\theta}(x_m,m)$ fixed according to the noise schedule. However, instead of predicting the noise $\epsilon_m$ at any arbitrary step $m$ as in DDPM, we train the network to predict the original noiseless signal $x_0$. The sample at the previous step $m-1$ can be obtained by noising back the predicted $x_0$. For the conditional generation setting, the network is additionally conditioned on the conditioning signal $c$ as $x_0 \approx \mathcal{G}_\theta(x_m,m,c)$ with model parameters $\theta$.
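One common way to "noise back" a predicted $x_0$ is to plug it into the standard DDPM posterior $q(x_{m-1} | x_m, x_0)$ and sample from it. The sketch below follows that recipe; it is an illustration under assumptions (toy schedule, a hypothetical `denoiser` callable standing in for $\mathcal{G}_\theta$) and not necessarily the exact sampler used in the paper.

```python
import numpy as np

M = 50
betas = np.linspace(1e-4, 0.2, M)              # toy schedule (assumed)
alphas = 1.0 - betas
# alpha_bar[m] for m = 0..M, with alpha_bar[0] = 1 (no noise added yet).
alpha_bar = np.concatenate([[1.0], np.cumprod(alphas)])

def posterior(x_m, x0_hat, m):
    # Mean and variance of q(x_{m-1} | x_m, x_0), using the predicted x_0.
    ab_m, ab_prev = alpha_bar[m], alpha_bar[m - 1]
    beta_m, alpha_m = betas[m - 1], alphas[m - 1]
    coef_x0 = np.sqrt(ab_prev) * beta_m / (1.0 - ab_m)
    coef_xm = np.sqrt(alpha_m) * (1.0 - ab_prev) / (1.0 - ab_m)
    mean = coef_x0 * x0_hat + coef_xm * x_m
    var = (1.0 - ab_prev) / (1.0 - ab_m) * beta_m
    return mean, var

def sample(denoiser, cond, shape, rng):
    x = rng.standard_normal(shape)             # x_M ~ N(0, I)
    for m in range(M, 0, -1):
        x0_hat = denoiser(x, m, cond)          # stands in for G_theta(x_m, m, c)
        mean, var = posterior(x, x0_hat, m)
        noise = rng.standard_normal(shape) if m > 1 else 0.0
        x = mean + np.sqrt(var) * noise
    return x

# Hypothetical "perfect" denoiser that always returns the conditioning target.
target = np.ones((4, 3))
out = sample(lambda x, m, c: c, target, target.shape, np.random.default_rng(0))
```

At the last step ($m = 1$) the posterior variance is zero and the mean reduces to the predicted $x_0$, so with the idealized denoiser above the sampler returns the target itself.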

Next

In the next post, we will present our proposal in detail.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 4)

In the previous post, we covered the multigraph parsing process, how to train a multigraph under decentralized federated learning, and the experimental setups for the multigraph. In this post, we discuss the effectiveness and efficiency of the multigraph topology design under different configurations.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Cycle Time Comparison

Table 1 shows the cycle time of our method in comparison with other recent approaches. This table illustrates that our proposed method significantly reduces the cycle time in all setups with different networks and datasets. In particular, compared to the state-of-the-art RING, our method reduces the cycle time by $2.18$, $1.5$, and $1.74$ times on average on the FEMNIST, iNaturalist, and Sentiment140 datasets, respectively. Our method also clearly outperforms MATCHA, MATCHA(+), and MST by a large margin. The results confirm that our multigraph with isolated nodes helps reduce the cycle time in federated learning.

From Table 1, our multigraph achieves its smallest improvement under the Amazon network on all three datasets. This is because, under the Amazon network, our proposed topology does not generate many isolated nodes, so the improvement is limited. Intuitively, when there are no isolated nodes, our multigraph becomes the overlay, and the cycle time of our multigraph equals the cycle time of the overlay in RING.

Tab-1

2. Isolated Node Analysis

Isolated Nodes vs. Network Configuration. The number of isolated nodes varies based on the network configuration (Amazon, Gaia, Exodus, etc.). The parameter $t$ (maximum number of edges between two nodes) and the delay time, which is determined by many factors (geographical distance, model size, computational cost based on tasks, bandwidth, etc.), also affect the process of generating isolated nodes. Table 2 illustrates the effectiveness of isolated nodes in different network configurations. Specifically, we conduct experiments on the FEMNIST dataset using five network configurations (Gaia, Amazon, Geant, Exodus, Ebone). We can see that our cycle time compared with RING is reduced significantly when more communication rounds or graph states have isolated nodes.

Tab-2

Table 2: The effectiveness of isolated nodes under different network configurations. All experiments are trained with 6400 communication rounds on the FEMNIST dataset. We then record the number of states and rounds in which isolated nodes appear and compare our cycle time with RING.

Isolated Nodes vs. RING vs. Random Strategy. Isolated nodes play an important role in our method as we can skip the model aggregation step in the isolated nodes. In practice, a trivial way to create isolated nodes is to randomly remove some nodes from the overlay of RING. Table 3 shows the experiment results in two scenarios on the FEMNIST dataset and Exodus network: (i) randomly remove some silos in the overlay of RING, and (ii) remove the most inefficient silos (i.e., silos with the longest delay) in the overlay of RING. From Table 3, the cycle time reduces significantly when either of these two scenarios is applied. However, the accuracy of the model also drops significantly. This experiment shows that although randomly removing some nodes from the overlay of RING is a trivial solution, it cannot maintain model accuracy. On the other hand, our multigraph not only reduces the cycle time of the model but also preserves the accuracy. This is because our multigraph can skip the aggregation step of the isolated nodes in a communication round. However, in the next round, the delay time of these isolated nodes will be updated, and they can become normal nodes and contribute to the final model.

Tab-3

Table 3: The cycle time and accuracy of our multigraph vs. RING with different criteria.

Isolated Nodes Illustration. The figure below shows a detailed illustration of our algorithm with the isolated nodes in a real-world training scenario. The experiment is conducted on the Gaia network geometry and its corresponding hardware for supporting link latency computation. The image classification task is chosen for this benchmark, using the FEMNIST dataset and the CNN backbone provided by Marfoq et al. Hence, we keep the transmitted model size at $4.62$ Mb, all access links have $10$ Gbps traffic capacity, the number of local updates is set to $1$, and the maximum number of edges $t$ is set to $3$. As shown in this figure, although there are no isolated nodes in the initial state, the number of isolated nodes rises sharply in the subsequent states (4 nodes). This leads to a $\sim 4$ times reduction in cycle time compared to the initial state. The appearance of isolated nodes also greatly reduces the number of connections between silos by $\sim 3.6$ times, from $11$ down to $3$ connections, and the discarded connections all have high latency.

Fig-1

3. Multigraph Ablation Study

Accuracy Analysis. In federated learning, improving the model accuracy is not the main focus of topology design methods. However, preserving the accuracy is also important to ensure model convergence. Table 4 shows the accuracy of different topologies after $6,400$ communication training rounds on the FEMNIST dataset. This table illustrates that our proposed method achieves competitive accuracy with other topology designs. This confirms that our topology can maintain the accuracy of the model, while significantly reducing the training time.

Tab-4

Table 4: Accuracy comparison between different topologies. The experiment is conducted using the FEMNIST dataset. The accuracy is reported after $6,400$ communication rounds in all methods.

Convergence Analysis. The figure below shows the training loss versus the number of communication rounds and the wall-clock time under the Exodus network using the FEMNIST dataset. This figure illustrates that our proposed topology converges faster than other methods while maintaining the model accuracy. We observe the same results in other datasets and network setups.

Cycle Time and Accuracy Trade-off. In our method, the maximum number of edges between two nodes $t$ mainly affects the number of isolated nodes. This leads to a trade-off between the model accuracy and cycle time. Table 5 illustrates the effectiveness of this parameter. When $t = 1$, we technically consider there to be no weak connections and no isolated nodes. Therefore, our method uses the original overlay from RING. When $t$ is set higher, we can increase the number of isolated nodes, hence decreasing the cycle time. In practice, too many isolated nodes will limit the model weights exchanged between silos. Therefore, models at isolated nodes are biased toward their local data, which consequently affects the final accuracy.

Tab-5

Table 5: Cycle time and accuracy trade-off with different values of $t$, i.e., the maximum number of edges between two nodes.

4. Conclusion

We proposed a new multigraph topology for cross-silo federated learning. Our method first constructs the multigraph using the overlay. Different graph states are then parsed from the multigraph and used in each communication round. Our method significantly reduces the cycle time by allowing the isolated nodes in the multigraph to do model aggregation without waiting for other nodes. The intensive experiments on three datasets show that our proposed topology achieves new state-of-the-art results in all network and dataset setups.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 3)

In the previous part, we investigated how delay time and cycle time are affected by the multigraph in the design of the topology, and explored how the multigraph can be constructed. In this post, we discuss the multigraph parsing process, how to train a multigraph under decentralized federated learning, and the experimental setups for the multigraph.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Multigraph Parsing

In the multigraph parsing algorithm, we parse the multigraph $\mathcal{G}_m$ into multiple graph states $\mathcal{G}_m^s$. Graph states are essential to identify the connection status of silos in a specific communication round to perform model aggregation. In each graph state, our goal is to identify the isolated nodes. During training, isolated nodes update their weights internally and ignore all weakly-connected edges that connect to them.

To parse the multigraph into graph states, we first identify the maximum number of states in a multigraph $s_{\max}$ by using the least common multiple (LCM). We then parse the multigraph into $s_{\max}$ states. The first state is always the overlay since we want to make sure all silos have a reliable topology at the beginning to ease the training. The remaining states are parsed so that there is only one connection between two nodes. Using our algorithm, some states will contain isolated nodes. During the training process, only one graph state is used in a communication round. The figure below illustrates the training process in each communication round using multiple graph states.
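The parsing step can be sketched as follows. The text does not spell out how edges are assigned to each state, so this example assumes one simple rule: a silo pair with $n$ edges uses its strongly-connected edge once every $n$ states and a weakly-connected edge otherwise. The rule, the function names, and the toy ring are all assumptions for illustration.

```python
from math import lcm

def parse_states(edge_counts):
    # edge_counts: {(i, j): number of edges between silos i and j}.
    s_max = lcm(*edge_counts.values())   # maximum number of graph states
    states = []
    for s in range(s_max):
        # Assumed rule: a pair with n edges is strongly connected when s % n == 0,
        # so state 0 always reproduces the full overlay.
        strong = {pair for pair, n in edge_counts.items() if s % n == 0}
        states.append(strong)
    return states

def isolated_nodes(nodes, strong_edges):
    # A node is isolated when none of its edges is strongly connected.
    connected = {v for pair in strong_edges for v in pair}
    return set(nodes) - connected

edge_counts = {(0, 1): 1, (1, 2): 2, (2, 3): 3, (3, 0): 2}   # toy ring overlay
states = parse_states(edge_counts)
isolated_per_state = [isolated_nodes(range(4), s) for s in states]
```

For this toy ring the LCM of the edge counts is 6, so there are six states; the first one is the overlay with no isolated nodes, and silos 2 and 3 are isolated in the second state.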

2. Multigraph Training

In each communication round, a state graph $\mathcal{G}_m^s$ is selected in a sequence that identifies the topology design used for training. We then collect all strongly-connected edges in the graph state $\mathcal{G}_m^s$ in such a way that nodes with strongly-connected edges need to wait for neighbors, while the isolated ones can update their models. We train our multigraph with the DPASGD algorithm:

$$w_{i}(k+1) = \begin{cases} \sum_{j \in \mathcal{N}_{i}^{++} \cup \{i\}} A_{i,j}\, w_{j}(k-h), & \forall k \equiv 0 \ \& \ \left|\mathcal{N}_{i}^{++}\right| > 1, \\ w_{i}(k) - \alpha_{k}\frac{1}{b}\sum^{b}_{h=1}\nabla L_i\left(w_{i}(k), \xi_i^{(h)}(k)\right), & \text{otherwise,} \end{cases}$$

where $(k-h)$ is the index of the considered weights; $h$ is initialized to $0$ and $h = h + 1 \ \forall e_{k-h}(i,j) = 0$. According to the equation above, at each state, if a silo is not an isolated node, it must wait for the model from its neighbor to update its weight. If a silo is an isolated node, it can use the model of its neighbor from round $(k-h)$ to update its weight immediately. The training procedure is described below:

3. Algorithm Complexity

It is trivial to see that the complexity of the training procedure is $\mathcal{O}(n^2)$. In practice, since the cross-silo federated learning setting has only a few hundred silos ($n<500$), the time to execute our algorithms is just a tiny fraction of the training time. Therefore, our proposed topology can still significantly reduce the overall wall-clock training time.

4. Experimental Setups

Datasets. We use three datasets in our experiments: Sentiment140, iNaturalist, and FEMNIST. All datasets and the pre-processing process are conducted by following recent works. Table below shows the dataset setups in detail.

Network. We consider five distributed networks in our experiments: Exodus, Ebone, Géant, Amazon and Gaia. The Exodus, Ebone, and Géant are from the Internet Topology Zoo. The Amazon and Gaia network are synthetic and are constructed using the geographical locations of the data centers.

Baselines. We compare our multigraph topology with recent state-of-the-art topology designs for federated learning: STAR, MATCHA, MATCHA(+), MST, and RING.

Hardware Setup. Since measuring the cycle time is crucial to compare the effectiveness of all topologies in practice, we test and report the cycle time of all baselines and our method on the same NVIDIA Tesla P100 16Gb GPUs. No overclocking is used.

Time Simulator. We adapted PyTorch with the MPI backend to run DPASGD and DPASGD++ on a GPU cluster. We take advantage of the network simulator, the Time Simulator, which uses an arbitrary topology and computation times of silos as input to calculate the time instants at which local models are computed. Reconstructing the wall-clock time with this simulator requires a thorough understanding of the topology, including all factors mentioned in the delay equations for each network configuration. The related configuration information is already provided in the Gaia network, and the simulator was created by Marfoq et al.

Next

In the next post, we will mention the effectiveness and efficiency of multigraph topology design under different configurations.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 2)

In the previous part, we explored decentralized federated learning and how and why the multigraph is proposed to improve the training process. In this part, we investigate how delay time and cycle time are affected by the multigraph in the design of the topology. We also explore how the multigraph can be constructed.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Delay time in multigraph

The delay of an edge $e(i, j)$ is the time interval until node $j$ receives the weights sent by node $i$, which can be defined by:

$$d(i,j) = u \times T_c(i) + l(i,j) + \frac{M}{O(i,j)},$$

where $T_{c}(i)$ denotes the time to compute one local update of the model; $u$ is the number of local updates; $l(i,j)$ is the link latency; $M$ is the model size; $O(i, j)$ is the total network traffic capacity. However, unlike other communication infrastructures, the multigraph only contains connections between silos without other nodes such as routers or amplifiers. Thus, the total network traffic capacity is $O(i,j) = \min\left(\frac{C_{\rm{UP}}(i)}{\left|\mathcal{N}_{i}^{-}\right|}, \frac{C_{\rm{DN}}(j)}{\left|\mathcal{N}_{i}^{+}\right|}\right)$, where $C_{\rm{UP}}$ and $C_{\rm{DN}}$ denote the upload and download link capacity. Note that the upload and download processes can happen in parallel.
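As a small worked example of the delay formula, the sketch below plugs in the model size and link capacity quoted elsewhere in this series (4.62 Mb, 10 Gbps, one local update). The computation time, link latency, and neighbor counts are made-up values, and the function names are hypothetical.

```python
def capacity(c_up, c_dn, n_up_links, n_dn_links):
    # O(i, j): the upload capacity is shared by the uploading links and the
    # download capacity by the downloading links; the slower side limits the rate.
    return min(c_up / n_up_links, c_dn / n_dn_links)

def delay(u, t_c, latency, model_size, cap):
    # d(i, j) = u * T_c(i) + l(i, j) + M / O(i, j)
    return u * t_c + latency + model_size / cap

# Units: seconds, megabits (Mb), and megabits per second (Mbps).
cap = capacity(c_up=10_000, c_dn=10_000, n_up_links=2, n_dn_links=2)
d = delay(u=1, t_c=0.005, latency=0.020, model_size=4.62, cap=cap)
```

Here each link is shared by two neighbors, so the effective capacity is 5000 Mbps and the transfer term contributes under one millisecond; with these assumed numbers the delay is dominated by latency and local computation.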

Since the multigraph can contain multiple edges between two nodes, we extend the definition of the delay in the previous equation to $d_k(i,j)$, where $k$ denotes the $k$-th communication round during the training process:

$$d_{k+1}(i,j) = \begin{cases}
d_k(i,j), & \text{if } e_{k+1}(i,j) = 1 \text{ and } e_{k}(i,j) = 1\\
\max\left(u \times T_c(j),\; d_{k}(i,j) - d_{k-1}(i,j)\right), & \text{if } e_{k+1}(i,j) = 1 \text{ and } e_{k}(i,j) = 0\\
\tau_k(\mathcal{G}_m) + d_{k-1}(i,j), & \text{if } e_{k+1}(i,j) = 0 \text{ and } e_{k}(i,j) = 1\\
\tau_k(\mathcal{G}_m), & \text{otherwise}
\end{cases}$$

where $e(i,j) = 0$ denotes a weakly-connected edge, $e(i,j) = 1$ denotes a strongly-connected edge, and $\tau_k(\mathcal{G}_m)$ is the cycle time at the $k$-th computation round during the training process.

In general, using the equation above, the delay of the next communication round $d_{k+1}$ is updated based on the delay of the previous rounds and other factors, depending on the edge type.
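The four cases of the delay update can be read as a small branching function. The sketch below is a direct transcription of the piecewise definition; the function and argument names are hypothetical, and edge types are encoded as 1 (strongly-connected) and 0 (weakly-connected).

```python
def next_delay(d_k, d_k_minus_1, e_next, e_curr, u, t_c_j, tau_k):
    """Return d_{k+1}(i, j) from the delays of the two previous rounds.

    e_next and e_curr are the edge types in rounds k+1 and k
    (1 = strongly-connected, 0 = weakly-connected); tau_k is the
    cycle time of round k.
    """
    if e_next == 1 and e_curr == 1:
        return d_k
    if e_next == 1 and e_curr == 0:
        return max(u * t_c_j, d_k - d_k_minus_1)
    if e_next == 0 and e_curr == 1:
        return tau_k + d_k_minus_1
    # Both rounds use a weakly-connected edge.
    return tau_k
```

Keeping the update as a pure function of the previous two delays makes it easy to replay a whole schedule of edge types round by round.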

2. Cycle time in multigraph

The cycle time per round is the time required to complete a communication round. In this work, we define the cycle time per round as the maximum delay between all silo pairs with strongly-connected edges. Therefore, the average cycle time of the entire training is:

$$\tau(\mathcal{G}_m) = \frac{1}{k}\sum^{k-1}_{k=0} \left(\max_{j \in \mathcal{N}^{++}_{i} \cup\{i\},\, \forall i \in \mathcal{V}} \left(d_k\left(j,i\right)\right)\right),$$

where $\mathcal{N}_{i}^{++}$ is the set of in-neighbor silos of $i$ whose edges are strongly-connected.
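A minimal sketch of this average, assuming the per-round delays are already known: each round is given as a dictionary of delays plus the set of pairs that are strongly-connected in that round. The data layout and names are assumptions for illustration; the self pairs $(i, i)$ are included explicitly because the maximum also ranges over $j = i$.

```python
def round_cycle_time(delays, strong_pairs):
    """Maximum delay d_k(j, i) over the strongly-connected pairs of one round.

    delays: dict mapping (j, i) -> delay in this round.
    strong_pairs: set of (j, i) pairs that count in this round; it should
    contain the self pairs (i, i) so that it is never empty.
    """
    return max(delays[pair] for pair in strong_pairs)


def average_cycle_time(rounds):
    """Average of the per-round cycle times over all recorded rounds."""
    return sum(round_cycle_time(d, s) for d, s in rounds) / len(rounds)


# Two silos: both links are strong in round 0, both are weak in round 1.
delays = {(0, 0): 0.25, (1, 1): 0.25, (0, 1): 0.5, (1, 0): 0.5}
rounds = [
    (delays, {(0, 0), (1, 1), (0, 1), (1, 0)}),
    (delays, {(0, 0), (1, 1)}),
]
tau = average_cycle_time(rounds)
```

In the second round neither silo waits for the other, so the round costs only the local computation time. This is the effect that isolated nodes are meant to exploit.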

3. Multigraph Construction

Algorithm 1 describes our method to generate the multigraph $\mathcal{G}_m$ with multiple edges between silos. The algorithm takes the overlay $\mathcal{G}_o$ as input. Similar to \cite{marfoq2020throughput}, we use the Christofides algorithm to obtain the overlay. In Algorithm 1, we establish multiple edges that indicate different statuses (strongly-connected or weakly-connected). To identify the total number of edges between a silo pair, we divide the delay $d(i,j)$ by the smallest delay $d_{\min}$ over all silo pairs, and compare the result with the maximum number of edges parameter $t$ ($t=5$ in our experiments). We assume that silo pairs with longer delays will have more weakly-connected edges, and hence are more likely to become isolated nodes. Overall, we aim to increase the number of weakly-connected edges, which generates more isolated nodes and speeds up the training process. Note that, from Algorithm 1, each silo pair in the multigraph should have one strongly-connected edge and multiple weakly-connected edges. The role of the strongly-connected edge is to make sure that two silos have a good connection in at least one communication round.
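Algorithm 1 itself is not reproduced in this post, so the following is only a sketch of the construction as described in the paragraph above. It assumes that the edge count is the rounded ratio $d(i,j) / d_{\min}$ capped at $t$, and that the first edge of every pair is the strongly-connected one; both choices, along with all names, are assumptions of this sketch rather than the authors' implementation.

```python
def build_multigraph(overlay_delays, t=5):
    """Assign a list of edge types to every overlay edge.

    overlay_delays: dict mapping (i, j) -> delay d(i, j) of an overlay edge.
    Returns a dict mapping (i, j) -> list of edge types, where 1 is a
    strongly-connected edge and 0 is a weakly-connected edge.
    """
    d_min = min(overlay_delays.values())
    multigraph = {}
    for pair, delay in overlay_delays.items():
        # Assumed rule: rounded delay ratio, at least 1 and at most t edges.
        n_edges = min(t, max(1, round(delay / d_min)))
        # One strongly-connected edge, the remaining ones weakly-connected.
        multigraph[pair] = [1] + [0] * (n_edges - 1)
    return multigraph


edges = build_multigraph({(0, 1): 1.0, (1, 2): 3.0, (2, 0): 10.0}, t=5)
```

Under this rule the slowest pair receives the most weakly-connected edges, so it is skipped in most rounds, while the fastest pair keeps only its strongly-connected edge.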

Next

In the next post, we will discuss the multigraph parsing process and how to train a multigraph under decentralized federated learning.

Reducing Training Time in Cross-Silo Federated Learning using Multigraph Topology (Part 1)

Federated learning is an active research topic since it enables several participants to jointly train a model without sharing local data. Currently, cross-silo federated learning is a popular training setting that utilizes a few hundred reliable data silos with high-speed access links to train a model. While this approach has been widely applied in real-world scenarios, designing a robust topology to reduce the training time remains an open problem. In this paper, we present a new multigraph topology for cross-silo federated learning. We first construct the multigraph using the overlay graph. We then parse this multigraph into different simple graphs with isolated nodes. The existence of isolated nodes allows us to perform model aggregation without waiting for other nodes, hence effectively reducing the training time. Intensive experiments on three public datasets show that our proposed method significantly reduces the training time compared with recent state-of-the-art topologies while maintaining the accuracy of the learned model.

Our code can be found at: https://github.com/aioz-ai/MultigraphFL

1. Introduction

Federated learning involves training models using remote devices or isolated data centers while keeping the data localized to respect user privacy policies. According to available literature, there are two prominent training scenarios: the "cross-device" scenario, which includes numerous unreliable edge devices with limited computational capacity and slow connection speeds, and the "cross-silo" scenario, which features a smaller number of reliable data silos with powerful computing resources and high-speed access links. Recently, the cross-silo scenario has gained traction in various federated learning applications.

In practical terms, federated learning represents a promising research avenue that allows us to harness the capabilities of machine learning techniques while upholding user privacy. Key obstacles in federated learning encompass issues like model convergence, communication bottlenecks, and disparities in data distributions across different silos. A commonly employed federated training approach involves establishing a central node responsible for overseeing the training process and aggregating contributions from all clients. However, a drawback of this client-server approach is the potential for communication bottlenecks, especially when dealing with a large number of clients. To mitigate this limitation, recent research has explored the concept of decentralized or peer-to-peer federated learning, where communication occurs via a peer-to-peer network topology, eliminating the need for a central node. Nevertheless, a major challenge in decentralized federated learning remains achieving rapid training while ensuring model convergence and preserving model accuracy.

In federated learning, the structure of communication networks holds significant importance. Specifically, an efficient network design contributes to quicker convergence, resulting in reduced training duration and energy consumption, as measured by worst-case convergence bounds within the topology's framework. Additionally, the topology's design has direct implications for other training-related challenges, including network congestion, overall model accuracy, and energy efficiency. The development of a resilient network structure capable of minimizing training time while preserving model accuracy remains an ongoing challenge in federated learning. Our paper is dedicated to devising a novel network design tailored for cross-silo federated learning, a prevalent scenario in practical applications.

Figure 1. We conducted a comparative analysis of various network structures using the FEMNIST dataset and the Exodus network. After completing 6,400 communication rounds, we measured and reported both the accuracy and the total wall-clock training time (or overhead time). Notably, our approach resulted in a substantial reduction in training duration while upholding model accuracy.

Lately, various network configurations have emerged for cross-silo federated learning. For instance, the STAR topology involves an orchestrator averaging all models during each communication round. Another approach, known as MATCHA, divides potential communications into pairs of clients, with random selection for model transmission in each round. Additionally, the RING topology employs max-plus linear systems. Despite progress in this field, challenges persist, including access link congestion, straggler effects, and the establishment of diverse topologies across communication rounds.

In this paper, we introduce a novel multigraph topology inspired by recent advancements in federated learning. Our aim is to enhance the efficiency of cross-silo federated learning. Our approach involves constructing a multigraph based on the overlay of existing network topologies. Subsequently, we decompose this multigraph into simpler graphs, each featuring only a single edge connecting two nodes. These individual graphs are referred to as "states" within the multigraph. Importantly, each state can involve isolated nodes that perform model aggregation independently, reducing the cycle time in each communication round significantly. Our intensive experiments demonstrate that our proposed topology outperforms existing state-of-the-art methods by a wide margin in terms of training time for cross-silo federated learning, as illustrated in Figure 1.

2. Overview

Federated Learning is recognized for its capacity to safeguard data privacy. In its modern incarnation, federated learning adopts a centralized network design, where a central node collects gradients from client nodes to update a global model. Early contributions in federated learning research include pioneering work and seminal papers by various researchers. Subsequent extensions and developments in federated learning and related distributed optimization algorithms have been proposed. Federated Averaging (FedAvg), initially introduced by one group, has inspired variations and other recent state-of-the-art model aggregation techniques, addressing convergence and the non-IID (non-identically and independently distributed) data challenge. Despite its simplicity, the client-server approach faces communication and computational bottlenecks at the central node, particularly when dealing with a large number of clients.

Decentralized Federated Learning flips the traditional federated learning model, enabling direct interactions between siloed data nodes, eliminating the necessity for a central coordinating node. While this approach mitigates communication congestion at a central point, optimizing a fully peer-to-peer network presents substantial challenges. The decentralized periodic averaging stochastic gradient descent method has demonstrated convergence rates comparable to centralized algorithms, making large-scale model training feasible. Furthermore, previous research has conducted systematic analyses of decentralized federated learning. A recent advancement involves leveraging a knowledge distillation mechanism to facilitate collaboration among silos in decentralized federated scenarios while preserving privacy among neighboring nodes.

Communication Topology plays a fundamental role in influencing the complexity and convergence behavior of federated learning. Numerous efforts have been dedicated to improving the efficiency of communication topologies, including star-shaped topologies and optimized-shaped topologies. In particular, a spanning tree topology has been introduced to reduce training time.

The STAR topology is designed for orchestrating the averaging of model updates in each communication round. Meanwhile, the MATCHA approach focuses on accelerating the training process through decomposition sampling. Recognizing the impact of straggler effects on communication round duration, methods for selecting the degree of a regular topology have been explored.

The RING topology is tailored for cross-silo federated learning and leverages the principles of max-plus linear systems. A sample-induced topology has been introduced, capable of effectively recovering the performance of existing SGD-based algorithms and their corresponding convergence rates. In a recent comprehensive survey, various models, frameworks, and algorithms related to network topologies in federated learning have been explored.

Multigraph is a concept that originates from traditional mathematics. In conventional terms, a "graph" typically denotes a simple graph without loops or multiple edges between two nodes. In contrast, a multigraph allows for the presence of multiple edges between two nodes. In the realm of deep learning, multigraphs have found utility across various domains, including clustering, medical image processing, traffic flow prediction, activity recognition, recommendation systems, and cross-domain adaptation. In this research, we employ a multigraph construction to facilitate isolated nodes and expedite training in cross-silo federated learning.

3. Preliminaries

3.1 Federated Learning

In federated learning, silos do not share their local data, but still periodically transmit model updates between them. Given $N$ siloed data centers, the objective function for federated learning is:

$$\min_{\textbf{w} \in \mathbb R^d} \sum^{N}_{i=1}p_i E_{\xi_i}\left[ L_{i}\left(\textbf{w}, \xi_i\right)\right],$$

where $L_{i}(\textbf{w}, \xi_i)$ is the loss of the model parameterized by the weight $\textbf{w} \in \mathbb R^d$, $\xi_i$ is an input sample drawn from the data at silo $i$, and the coefficient $p_i>0$ specifies the relative importance of each silo. Recently, different distributed algorithms have been proposed to optimize this objective. In this work, DPASGD is used to update the weight of silo $i$ in each training round as follows:

$$\textbf{w}_{i}\left(k + 1\right) = \begin{cases}
\sum_{j \in \mathcal{N}_i^{+} \cup \{i\}}\textbf{A}_{i,j}\textbf{w}_{j}\left(k\right), & \text{if } k \equiv 0 \left(\text{mod } u + 1\right),\\
\textbf{w}_{i}\left(k\right)-\alpha_{k}\frac{1}{b}\sum^b_{h=1}\nabla L_i\left(\textbf{w}_{i}\left(k\right),\xi_i^{\left(h\right)}\left(k\right)\right), & \text{otherwise,}
\end{cases}$$

where $b$ is the batch size, $i,j$ denote the silos, $u$ is the number of local updates, $\alpha_k > 0$ is a potentially varying learning rate at the $k$-th round, $\textbf{A} \in \mathbb R^{N \times N}$ is a consensus matrix with non-negative weights, and $\mathcal{N}_i^{+}$ is the set of in-neighbors that silo $i$ is connected to.
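To make the two branches of the DPASGD update concrete, here is a small pure-Python sketch that applies one round to all silos. It assumes the mini-batch-averaged gradients are computed elsewhere and passed in; the function name and data layout (nested lists) are choices made for this sketch, not the API of the released code.

```python
def dpasgd_step(k, weights, consensus, in_neighbors, grads, lr, u):
    """Apply one DPASGD round to every silo and return the new weights.

    weights: list of per-silo weight vectors (lists of floats).
    consensus: N x N consensus matrix A as nested lists.
    in_neighbors: in_neighbors[i] lists the in-neighbors of silo i.
    grads: per-silo gradients, already averaged over the mini-batch.
    """
    n = len(weights)
    if k % (u + 1) == 0:
        # Aggregation round: weighted sum over in-neighbors and the silo itself.
        new_weights = []
        for i in range(n):
            group = list(in_neighbors[i]) + [i]
            dim = len(weights[i])
            new_weights.append(
                [sum(consensus[i][j] * weights[j][d] for j in group) for d in range(dim)]
            )
        return new_weights
    # Local round: one SGD step on each silo's own data.
    return [[w - lr * g for w, g in zip(weights[i], grads[i])] for i in range(n)]


# Two silos with a one-dimensional model and uniform averaging.
w0 = [[0.0], [2.0]]
A = [[0.5, 0.5], [0.5, 0.5]]
neighbors = [[1], [0]]
g = [[1.0], [1.0]]
w_agg = dpasgd_step(0, w0, A, neighbors, g, lr=0.5, u=1)
w_local = dpasgd_step(1, w0, A, neighbors, g, lr=0.5, u=1)
```

With $u = 1$, rounds alternate between aggregation and a single local step, so communication happens every other round.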

3.2 Multigraph for Federated Learning

Connectivity and Overlay. We consider the connectivity $\mathcal{G}_c = (\mathcal{V}, \mathcal{E}_c)$ as a graph that captures possible direct communications among silos. Based on its definition, the connectivity is often a fully connected graph and is also a directed graph. The overlay $\mathcal{G}_o$ is a connected subgraph of the connectivity graph, i.e., $\mathcal{G}_o = (\mathcal{V}, \mathcal{E}_o)$, where $\mathcal E_o \subset \mathcal E_c$. Only nodes directly connected in the overlay graph $\mathcal{G}_o$ will exchange messages during training.

Multigraph. While the connectivity and overlay graph can represent different topologies for federated learning, one of their drawbacks is that there is only one connection between two nodes. In our work, we construct a multigraph $\mathcal{G}_m = (\mathcal{V}, \mathcal{E}_m)$ from the overlay $\mathcal{G}_o$. The multigraph can contain multiple edges between two nodes. In practice, we parse this multigraph into different graph states, where each state is a simple graph with only one edge between two nodes.

In the multigraph $\mathcal{G}_m$, the connection edge between two nodes has two types: strongly-connected edge and weakly-connected edge. Under both strong and weak connections, the participating nodes can transmit their trained models to their out-neighbours $\mathcal{N}_i^{-}$ or download models from their in-neighbours $\mathcal{N}_i^{+}$. However, in a strongly-connected edge, two nodes in the graph must wait until all upload and download processes between them are finished before doing model aggregation. On the other hand, in a weakly-connected edge, the model aggregation process in each node can start whenever the previous training process is finished, by leveraging up-to-date models from the in-neighbours of that node which have not been used before.

State of Multigraph. Given a multigraph $\mathcal{G}_m$, we can parse this multigraph into different simple graphs with only one connection between two nodes (either strongly-connected or weakly-connected). We denote each simple graph as a state $\mathcal{G}_m^s$ of the multigraph.

Isolated Node. A node is called isolated when all of its connections to other nodes are weakly-connected edges. Figure 2 shows the graph concepts and isolated nodes.

Figure 2. Example of connectivity, overlay, multigraph, and a state of our multigraph. Blue node is an isolated node. Dotted line denotes a weakly-connected edge.
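The isolated-node definition can be checked mechanically on a single state. The sketch below assumes a state is stored as a dictionary from node pairs to edge types (1 for strongly-connected, 0 for weakly-connected); this representation and the function name are illustrative assumptions.

```python
def isolated_nodes(state, nodes):
    """Return the nodes whose edges in this state are all weakly-connected.

    state: dict mapping (i, j) -> 1 (strongly-connected) or 0 (weakly-connected).
    nodes: iterable of all node ids in the graph.
    A node with no edges at all is also reported, since it has no
    strongly-connected edge either.
    """
    has_strong_edge = set()
    for (i, j), edge_type in state.items():
        if edge_type == 1:
            has_strong_edge.update((i, j))
    return [v for v in nodes if v not in has_strong_edge]


# Node 2 only has weakly-connected edges, so it does not wait for anyone.
state = {(0, 1): 1, (1, 2): 0, (2, 0): 0}
isolated = isolated_nodes(state, [0, 1, 2])
```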

Next

In the next post, we will discuss the delay time and cycle time, as well as how the multigraph can be constructed.

Deep Federated Learning for Autonomous Driving (Part 2)

In the previous part, we discussed the Federated Autonomous Driving network (FADNet). In this post, we verify its effectiveness and efficiency.

Our source code can be found at: https://github.com/aioz-ai/FADNet

1. Experimental Setup

Udacity. We use the popular Udacity dataset to evaluate our results. We only use the front-facing images of this dataset in our experiment. We use 5 sequences for training and 1 for testing. The training sequences are assigned randomly to different silos depending on the federated topology (i.e., Gaia or NWS).

Carla. Since the Udacity dataset is collected in a real-world environment, changing the weather or lighting conditions is not easy. To this end, we collect more simulated data in the Carla simulator. We apply different lighting (morning, noon, night, sunrise, sunset) and weather conditions (cloudy, rain, heavy rain, wet streets, windy, snowy) when collecting the data. We generate 73,235 samples distributed over 11 sequences of scenes.

Gazebo. Since both the Udacity and Carla datasets are collected in outdoor environments, we also employ Gazebo to collect data for autonomous navigation in indoor scenes. We use a simulated mobile robot and the built-in scenes to collect data. Table 1 shows the statistics of the three datasets. We use 80% of the collected Gazebo and Carla data for training, and the remaining 20% for testing.

Figure 1. Visualization of sample images in three datasets: Udacity (first row), Gazebo (second row), and Carla (third row).

| Dataset | Total samples | Average samples in each silo (Gaia) | Average samples in each silo (NWS) |
| --- | --- | --- | --- |
| Udacity | 39,087 | 3,553 | 1,777 |
| Gazebo | 66,806 | 6,073 | 3,037 |
| Carla | 73,235 | 6,658 | 3,329 |

Table 1. Statistics of the datasets in our experiments.

Network Topology. We conduct experiments on two topologies: the Internet Topology Zoo (Gaia), and the North America data centers (NWS). We use Gaia topology in our main experiment and provide the comparison of two topologies in our ablation study.

Training. The model in a silo is trained with a batch size of 32 and a learning rate of 0.001 using the Adam optimizer. We follow the training process to obtain a global weight of all silos. The training process is conducted with 3,000 communication rounds, and each silo has one NVIDIA 1080 11 GB GPU for training. Note that one communication round is counted each time all silos have finished updating their model weights.

Baselines. We compare our results with various recent methods, including Random baseline and Constant baseline, Inception-V3, MobileNet-V2, VGG-16, and Dronet. All these methods use the Centralized Local Learning (CLL) strategy (i.e., the data are collected and trained in one local machine.) For distributed learning, we compare our Deep Federated Learning (DFL) approach with the Server-based Federated Learning (SFL) strategy. As the standard practice, we use the root-mean-square error (RMSE) metric to evaluate the results.
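For reference, the RMSE metric used throughout the tables can be computed as in the sketch below, applied to predicted and ground-truth steering angles. The function name and the input validation are choices made for this sketch.

```python
import math


def rmse(predictions, targets):
    """Root-mean-square error between two equally long sequences of numbers."""
    if len(predictions) != len(targets) or len(predictions) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))
```

Lower values are better, and an RMSE of 0 means every predicted angle matches its ground truth.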

2. Results

Table 2 summarises the performance of our method and recent state-of-the-art approaches. Note that our FADNet is trained using the proposed peer-to-peer DFL on the Gaia topology with 11 silos. The table shows that our FADNet + DFL outperforms other methods by a fair margin. In particular, FADNet + DFL significantly reduces the RMSE on the Gazebo and Carla datasets, while slightly outperforming DroNet on the Udacity dataset. These results validate the robustness of our FADNet when trained in a fully decentralized setting. Table 3 also shows that with a proper deep architecture such as our FADNet, we can achieve state-of-the-art accuracy when training the deep model in FL. Fig. 2 illustrates the spatial support regions when our FADNet makes a prediction. In particular, FADNet focuses on the "line-like" patterns in the input frame, which guide the driving direction.

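| Architecture | Learning Method | Udacity | Gazebo | Carla | #Params |
| --- | --- | --- | --- | --- | --- |
| Random | - | 0.301 | 0.117 | 0.464 | - |
| Constant | - | 0.209 | 0.092 | 0.348 | - |
| Inception | CLL | 0.154 | 0.085 | 0.297 | 21,787,617 |
| MobileNet | CLL | 0.142 | 0.083 | 0.286 | 2,225,153 |
| VGG-16 | CLL | 0.121 | 0.083 | 0.316 | 7,501,587 |
| DroNet | CLL | 0.110 | 0.082 | 0.333 | 314,657 |
| FADNet (ours) | DFL | 0.107 | 0.069 | 0.203 | 317,729 |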

Table 2. Performance comparison of different architectures on the Udacity, Gazebo, and Carla datasets. The number of parameters (#Params) is also provided.

Figure 2. Spatial support regions for predicting steering angle in three datasets. In most cases, we can observe that our FADNet focuses on "line-like" patterns to predict the driving direction.

3. Ablation Studies

Effectiveness of our DFL.

Table 3 summarises the accuracy of DroNet and our FADNet when we train them using different learning methods: CLL, SFL, and our peer-to-peer DFL. From this table, we can see that training both DroNet and FADNet with our peer-to-peer DFL clearly improves the accuracy compared with the SFL approach. This confirms the robustness of our fully decentralized approach and removes the need for a central server when we train a deep network with FL. Compared with the traditional CLL approach, our DFL also shows competitive performance. However, we note that training a deep architecture using CLL is less complicated than with SFL or DFL. Furthermore, CLL is not a federated learning approach and does not take into account the privacy of the user data.

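| Architecture | Learning Method | Udacity | Gazebo | Carla |
| --- | --- | --- | --- | --- |
| DroNet | CLL | 0.110 | 0.082 | 0.333 |
| DroNet | SFL | 0.176 | 0.081 | 0.297 |
| DroNet | DFL (ours) | 0.152 | 0.073 | 0.244 |
| FADNet (ours) | CLL | 0.142 | 0.081 | 0.303 |
| FADNet (ours) | SFL | 0.151 | 0.071 | 0.211 |
| FADNet (ours) | DFL (ours) | 0.107 | 0.069 | 0.203 |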

Table 3. Performance comparison of different methods.

Effectiveness of our FADNet.

Table 3 shows that apart from the learning method, the deep architecture also affects the final results. This table illustrates that our FADNet combined with DFL outperforms DroNet in all configurations. We notice that DroNet achieves competitive results when trained with CLL. However, DroNet is not designed for federated training, hence it does not achieve good accuracy when trained with SFL or DFL. On the other hand, our FADNet is designed with dedicated layers to handle the data imbalance and model convergence problems in federated training. Therefore, FADNet achieves new state-of-the-art results on all three datasets.

Network Topology Analysis.

Table 4 illustrates the performance of DroNet and our FADNet when we train them using DFL under two distributed network topologies: Gaia and NWS. This table shows that the results of DroNet and FADNet under DFL are stable in both the Gaia and NWS distributed networks. We note that the NWS topology has 22 silos while the Gaia topology has only 11 silos. This result validates that our FADNet and DFL do not depend on the distributed network topology. Therefore, we can potentially use them in practice with more silos.

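| Network Topology | Architecture | Udacity | Gazebo | Carla |
| --- | --- | --- | --- | --- |
| Gaia (11 silos) | DroNet | 0.152 | 0.073 | 0.244 |
| Gaia (11 silos) | FADNet (ours) | 0.107 | 0.069 | 0.203 |
| NWS (22 silos) | DroNet | 0.157 | 0.075 | 0.239 |
| NWS (22 silos) | FADNet (ours) | 0.109 | 0.070 | 0.200 |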

Table 4. Performance comparison of different network topologies.

Convergence Analysis.

The effectiveness of federated learning algorithms is identified through the convergence ability, including accuracy and training speed, especially when dealing with the increasing number of silos in practice. Fig. 3 shows the convergence ability of our FADNet with DFL using two topologies: Gaia with 11 silos, and NWS with 22 silos. This figure shows that our proposed DFL achieves the best results on the Gaia and NWS topologies and converges faster than the SFL approach on both the Gazebo and Carla datasets. We also notice that the performance of our DFL is stable when the number of silos increases. Specifically, training our FADNet with DFL reaches the converged point after approximately 150 s and 180 s on the NWS and Gaia topologies, respectively. Fig. 3 validates the convergence ability of our FADNet and DFL, especially when dealing with the increasing number of silos.

In practice, compared with the traditional CLL approach, federated learning methods such as SFL or DFL can leverage more GPUs remotely. Therefore, we can reduce the total training time significantly. However, the drawback of federated learning is we would need more GPUs in total (ideally one for each silo), and deep architecture also should be carefully designed to ensure model convergence.

Figure 3. The convergence ability of our FADNet and DFL under the Gaia and NWS topologies. Wall-clock time or elapsed real-time is the actual time taken from the start of the whole training process to the end, including the synchronization time of the weight aggregation process. All experiments are conducted with 3,000 communication rounds.

Deployment

To verify the effectiveness of our FADNet in practice, we deploy the model trained on the Gazebo dataset on a mobile robot. The robot is equipped with a RealSense camera to capture the front RGB images. Our FADNet is deployed on a Qualcomm RB5 board to predict the steering angle for the robot. The processing speed of our FADNet on the Qualcomm RB5 board is approximately 12 frames per second. Overall, we observe that the robot can navigate smoothly in an indoor environment without colliding with obstacles. More qualitative results can be found in our supplementary material.

Conclusion

We propose a new approach to learn an autonomous driving policy from sensory data without violating the user's privacy. We introduce a peer-to-peer deep federated learning (DFL) method that effectively utilizes the user data in a fully distributed manner. Furthermore, we develop a new deep architecture - FADNet that is well suitable for distributed training. The intensive experimental results on three datasets show that our FADNet with DFL outperforms recent state-of-the-art methods by a fair margin. Currently, our deployment experiment is limited to a mobile robot in an indoor environment. In the future, we would like to test our approach with more silos and deploy the trained model using an autonomous car on man-made roads.

Deep Federated Learning for Autonomous Driving (Part 1)

Autonomous driving is an active research topic in both academia and industry. However, most of the existing solutions focus on improving the accuracy by training learnable models with centralized large-scale data. Therefore, these methods do not take into account the user's privacy. In this paper, we present a new approach to learn an autonomous driving policy while respecting privacy concerns. We propose a peer-to-peer Deep Federated Learning (DFL) approach to train deep architectures in a fully decentralized manner and remove the need for central orchestration. We design a new Federated Autonomous Driving network (FADNet) that can improve the model stability, ensure convergence, and handle imbalanced data distribution problems while being trained with federated learning methods. Intensive experimental results on three datasets show that our approach with FADNet and DFL achieves superior accuracy compared with other recent methods. Furthermore, our approach can maintain privacy by not collecting user data on a central server.

Our source code can be found at: https://github.com/aioz-ai/FADNet

1. Introduction

In this paper, our goal is to develop an end-to-end driving policy from sensory data while maintaining the user's privacy by utilizing FL. We address the key challenges in FL to make sure our deep network can achieve competitive performance when trained in a fully decentralized manner. Fig. 1 shows an overview of different learning approaches for autonomous driving. In Centralized Local Learning (CLL), the data are collected and trained in one local machine. Hence, the CLL approach does not take into account the user's privacy. The Server-based Federated Learning (SFL) strategy requires a central server to orchestrate the training process and receive the contributions of all clients. The main limitation of SFL is communication congestion when the number of clients is large. Therefore, we follow the peer-to-peer federated learning setup for training. Our peer-to-peer Deep Federated Learning (DFL) is fully decentralized and can reduce communication congestion during training. We also propose a new Federated Autonomous Driving network (FADNet) to address the problem of model convergence and imbalanced data distribution. By training our FADNet using DFL, our approach outperforms recent state-of-the-art methods by a fair margin while maintaining user data privacy.

Figure 1. An overview of different learning methods for autonomous driving. (a) Centralized Local Learning, (b) Server-based Federated Learning, and (c) our peer-to-peer Deep Federated Learning. Red arrows denote the aggregation process between silos. Yellow lines with a red cross indicate the non-sharing data between silos.

Our contributions can be summarized as follows:

  • We propose a fully decentralized, peer-to-peer Deep Federated Learning framework for training autonomous driving solutions.
  • We introduce a Federated Autonomous Driving network that is well suitable for federated training.
  • We introduce two new datasets and conduct intensive experiments to validate our results.

2. Problem Formulation

We consider a federated network with $N$ siloed data centers (e.g., autonomous cars) $\mathcal{D}_{i}$, with $i \in [1,N]$. Our goal is to collaboratively train a global driving policy $\theta$ by aggregating all local learnable weights $\theta_i$ of each silo. Note that, unlike the popular centralized local training setup, in FL training, each silo does not share its local data, but periodically transmits model updates to other silos.

In practice, each silo has the training loss $\mathcal{L}_i(\xi_i, \theta_i)$, where $\xi_i$ is the ground truth in silo $i$. $\mathcal{L}_i(\xi_i, \theta_i)$ is calculated as the regression loss. This regression loss is modeled by a deep network that takes RGB images as inputs and predicts the associated steering angles.

3. Deep Federated Learning for Autonomous Driving

A popular training method in FL is to set up a central server that orchestrates the training process and receives the contributions of all clients (Server-based Federated Learning - SFL). The limitation of SFL is that the server potentially represents a single point of failure in the system. We may also have communication congestion between the server and clients when the number of clients is massive. Therefore, in this work, we utilize peer-to-peer FL to set up the training scenario. In peer-to-peer FL, there is no centralized orchestration, and the communication is via a peer-to-peer topology. However, the main challenge of peer-to-peer FL is to assure model convergence and maintain accuracy in a fully decentralized training setting.

Figure 2. An overview of our peer-to-peer Deep Federated Learning method. (a) A simplified version of an overlay graph. (b) The training methodology in the overlay graph. Note that blue arrows denote the local training process in each silo; red arrows denote the aggregation process between silos controlled by the overlay graph; yellow lines with a red cross indicate the non-sharing data between silos; the arrow indicates that the process is parallel.

Fig.2 illustrates our Deep Federated Learning (DFL) method. Our DFL follows the peer-to-peer FL setup with the goal to integrate a deep architecture into a fully decentralized setting that ensures convergence while achieving competitive results compared to the traditional Centralized Local Learning or SFL approach. In practice, we can consider a silo as an autonomous car. Each silo maintains a local learnable model and does not share its data with other silos. We represent the silos as vertices of a communication graph and the FL is performed on an overlay, which is a sub-graph of this communication graph.

Designing the Overlay

Let $\mathcal{G}_c = (\mathcal{V}, \mathcal{E}_c)$ be the connectivity graph that captures the possible direct communications among the $N$ silos. $\mathcal{V}$ is the set of vertices (silos), while $\mathcal{E}_c$ is the set of communication links between vertices. $\mathcal{N}_i^{+}$ and $\mathcal{N}_i^{-}$ are the in-neighbors and out-neighbors of a silo $i$, respectively. As in \cite{marfoq2020throughput}, we note that it is unnecessary to use all the connections of the connectivity graph for FL. Indeed, a sub-graph called an overlay, $\mathcal{G}_o = (\mathcal{V}, \mathcal{E}_o)$, can be generated from $\mathcal{G}_c$. In our work, $\mathcal{G}_o$ is the result of Christofides' Algorithm \cite{monnot2003approximation}, which yields a strong spanning sub-graph of $\mathcal{G}_c$ with minimal cycle time. One cycle time, or time per communication round, is in general the time that a vertex waits for messages from the other vertices before doing a computational update.

In practice, one block cycle time of an overlay $\mathcal{G}_o$ depends on the delay of each link $(i, j)$, denoted as $d_o(i, j)$, which is the time interval between the beginning of a local computation at node $i$ and the receipt of $i$'s messages by $j$. Furthermore, without considering access link delays between vertices, our graph is treated as an edge-capacitated network with:

$$d_o(i,j) = s \times T_c(i) + l(i,j) + \frac{M}{B(i,j)}$$

where $T_c(i)$ is the time to compute one local update of the model; $s$ is the number of local computational steps; $l(i,j)$ is the link latency; $M$ is the model size; and $B(i,j)$ is the available bandwidth of the path $(i,j)$. As in \cite{marfoq2020throughput}, we set $s=1$.
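As a quick numeric illustration of the link delay above, the following minimal sketch evaluates $d_o(i,j)$ for one link. The function name `link_delay` and all numbers are assumptions chosen for the example, not values from the experiments.

```python
def link_delay(t_compute, latency, model_size, bandwidth, local_steps=1):
    """Delay of link (i, j): d_o = s * T_c(i) + l(i, j) + M / B(i, j).

    t_compute   -- T_c(i), seconds per local update at silo i
    latency     -- l(i, j), link latency in seconds
    model_size  -- M, in the same data unit as the bandwidth (e.g. Mbit)
    bandwidth   -- B(i, j), data units per second
    local_steps -- s, number of local computation steps (s = 1 in the text)
    """
    if bandwidth <= 0:
        raise ValueError("bandwidth must be positive")
    return local_steps * t_compute + latency + model_size / bandwidth


# Hypothetical numbers: 50 ms per update, 20 ms latency,
# a 40 Mbit model sent over a 100 Mbit/s path.
delay = link_delay(t_compute=0.05, latency=0.02, model_size=40.0, bandwidth=100.0)
print(f"d_o(i, j) = {delay:.2f} s")
```

With these made-up values the transmission term $M / B(i,j)$ dominates, which is the situation in which a carefully chosen overlay is most useful.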

Training Algorithm

At each silo $i$, the optimization problem to be solved is:

$$\theta_i^{*} = \underset{\theta_i}{\arg\min}\; \underset{\xi_i \sim \mathcal{D}_i}{\mathbb{E}}\left[\mathcal{L}(\xi_i, \theta_i)\right]$$

We apply the distributed federated learning algorithm DPASGD to solve the optimization problems of all silos. After waiting one cycle time, each silo $i$ receives parameters $\theta_j$ from its in-neighbors $\mathcal{N}_i^{+}$ and accumulates these parameters, each multiplied by a non-negative coefficient from the consensus matrix $\mathbf{A}$. It then performs $s$ mini-batch gradient updates before sending $\theta_i$ to its out-neighbors $\mathcal{N}_i^{-}$, and the algorithm keeps repeating. Formally, at each iteration $k$, the updates are described as:

$$\theta_{i}(k + 1) = \begin{cases} \sum_{j \in \mathcal{N}_i^{+} \cup \{i\}} \mathbf{A}_{i,j}\,\theta_{j}(k), & \text{if } k \equiv 0 \pmod{s + 1},\\ \theta_{i}(k) - \alpha_{k}\frac{1}{m}\sum^{m}_{h=1}\nabla \mathcal{L}\left(\theta_{i}(k), \xi_i^{(h)}(k)\right), & \text{otherwise,} \end{cases}$$

where $m$ is the mini-batch size and $\alpha_k > 0$ is a potentially varying learning rate.
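The two branches of the update rule can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: parameters are flat lists of floats, the mini-batch averaged gradients are passed in precomputed, and the names `dpasgd_step`, `in_neighbors`, and `grads` are hypothetical.

```python
def dpasgd_step(k, s, thetas, A, in_neighbors, grads, lr):
    """Apply one iteration k of the update rule to every silo.

    thetas       -- thetas[i] is the parameter vector of silo i
    A            -- consensus matrix, A[i][j] >= 0
    in_neighbors -- in_neighbors[i] lists the in-neighbors of silo i
    grads        -- grads[i] is the mini-batch averaged gradient at thetas[i]
    lr           -- learning rate alpha_k
    """
    updated = []
    for i, theta in enumerate(thetas):
        if k % (s + 1) == 0:
            # Consensus step: weighted sum over in-neighbors and silo i itself.
            members = list(in_neighbors[i]) + [i]
            mixed = [sum(A[i][j] * thetas[j][d] for j in members)
                     for d in range(len(theta))]
            updated.append(mixed)
        else:
            # Local step: one mini-batch gradient update.
            updated.append([p - lr * g for p, g in zip(theta, grads[i])])
    return updated


# Two silos that are each other's in-neighbor, with uniform mixing weights.
A = [[0.5, 0.5], [0.5, 0.5]]
neighbors = [[1], [0]]
thetas = [[0.0], [2.0]]
mixed = dpasgd_step(0, 1, thetas, A, neighbors, grads=[[0.0], [0.0]], lr=0.1)
local = dpasgd_step(1, 1, thetas, A, neighbors, grads=[[1.0], [-1.0]], lr=0.1)
```

With $s = 1$, even iterations mix parameters across neighbors and odd iterations take a local gradient step.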

Federated Averaging

To compute the prediction of the models in all silos, we compute the average model $\theta$ by aggregating the weights of all local models $\theta_i$. The federated averaging process is conducted as follows:

$$\theta = \frac{1}{\sum^{N}_{i=1}\lambda_i} \sum^{N}_{i=1}\lambda_{i}\,\theta_{i}$$

where $N$ is the number of silos and $\lambda_i \in \{0,1\}$. Note that $\lambda_i = 1$ indicates that silo $i$ joins the inference process and $\lambda_i = 0$ indicates that it does not. The aggregated weight $\theta$ is then used for evaluation on the testing set $\mathcal{D}_{test}$.
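The averaging step is simple enough to show directly. The sketch below assumes each local model is a flat list of floats; `federated_average` is a hypothetical helper name.

```python
def federated_average(thetas, lambdas):
    """Average the local models theta_i selected by lambda_i in {0, 1}."""
    total = sum(lambdas)
    if total == 0:
        raise ValueError("at least one silo must participate")
    dim = len(thetas[0])
    return [sum(lam * theta[d] for lam, theta in zip(lambdas, thetas)) / total
            for d in range(dim)]


# Three silos; the third one does not join the inference process.
local_models = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
global_model = federated_average(local_models, lambdas=[1, 1, 0])
```

Silos with $\lambda_i = 0$ contribute neither to the numerator nor to the normalizing sum, so the result is the plain mean of the participating models.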

4. Network Architecture

One of the main challenges when training a deep network in FL is the imbalanced and non-IID (non-independent and identically distributed) partitioning of data across silos. To overcome this problem, the learning architecture should be designed to balance the trade-off between convergence ability and accuracy. In practice, the deep architecture has to deal with the high variance between silo weights when the accumulation process for all silos is conducted. To this end, we design a new Federated Autonomous Driving Network (FADNet), which is based on ResNet8, as shown in Fig.3.

Figure 3. The architecture of our Federated Autonomous Driving Network (FADNet).

In particular, our proposed FADNet first comprises an input layer normalization to improve the stability of the abstract layer. This layer aims to handle the different distributions of input images in each silo. Then, a convolution layer followed by a max-pooling layer is added to encode the input. To handle the vanishing gradient problem, three residual blocks are appended, followed by an FC layer, to extract ResBlock features. However, using residual blocks increases the variance of silo weights during the aggregation process and affects the convergence ability of the model. To address this problem, we add a Global Average Pooling layer (GAP) associated with each residual block. GAP is a non-weight pooling layer that averages out the spatial information from each residual block. Thus, it is not affected by the weight variance problem. The output of each GAP layer is passed through an Accumulation layer to accrue the Support feature. The ResBlock feature and the Support feature from the GAP layers are fed into the Aggregation layer to calculate the model loss in each silo.

In our design, the Accumulation and Aggregation layers aim to reduce the variance of the global model since we need to combine multiple model weights produced by different silos. In particular, the Accumulation layer is a variant of the fully connected (FC) layer. Instead of weighting the contribution of input nodes as in FC, the Accumulation layer weights the contribution of multiple features from input layers. The Accumulation layer has a learnable weight $w \in \mathbb{R}^\text{n}$. Its number of nodes is equal to the number $\text{n}$ of input layers. Note that the support feature from the Accumulation layer has the same size as the input. Let $F = \{f_\text{1}, f_\text{2}, ..., f_\text{n}\}, \forall f_\text{h} \in \mathbb{R}^\text{d}$ be the collection of the $\text{n}$ features extracted from the $\text{n}$ input GAP layers; $\text{d}$ is the unified dimension. The Accumulation layer outputs a feature $f_\text{c} \in \mathbb{R}^\text{d}$ in each silo $i$, computed as:

$$f_\text{c} = \text{Accumulation}(F)_i = \sum^{\text{n}}_{\text{h}=1}(w_\text{h}f_\text{h})_i$$

The Aggregation layer is a fusion between the ResBlock feature extracted from the backbone and the support feature from the Accumulation layer. For simplicity, we use the Hadamard product to compute the aggregated feature. This feature is then averaged to predict the steering angle. Let $f_\text{s} \in \mathbb{R}^\text{d}$ be the ResBlock features extracted from the backbone. The output driving policy $\theta_i$ of silo $i$ can be calculated as:

$$\theta_i = \text{Aggregation}(f_\text{s}, f_\text{c})_i = \bar{(f_\text{s} \odot f_\text{c})}_i$$

where $\odot$ denotes the Hadamard product, $\bar{(*)}$ denotes the mean, and we set $\text{d} = 6{,}272$.
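To make the two layers concrete, here is a minimal sketch of the Accumulation and Aggregation computations on toy vectors. It covers only the forward arithmetic of the two equations; the function names and the tiny dimension are assumptions for illustration, and the learnable weights are fixed by hand.

```python
def accumulation(features, weights):
    """Support feature f_c = sum_h w_h * f_h over the n GAP features."""
    if len(features) != len(weights):
        raise ValueError("one weight per input feature is required")
    dim = len(features[0])
    return [sum(w * f[k] for w, f in zip(weights, features)) for k in range(dim)]


def aggregation(f_s, f_c):
    """Mean of the Hadamard (element-wise) product of f_s and f_c."""
    if len(f_s) != len(f_c):
        raise ValueError("features must share the unified dimension d")
    product = [a * b for a, b in zip(f_s, f_c)]
    return sum(product) / len(product)


# Toy example with n = 2 GAP features of dimension d = 2.
gap_features = [[1.0, 2.0], [3.0, 4.0]]
f_c = accumulation(gap_features, weights=[0.5, 0.5])
steering = aggregation([1.0, 2.0], f_c)
```

Because the Accumulation layer has only one weight per input feature rather than one per node, it adds very few parameters whose values can diverge across silos.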

Next

In the next post, we will show the effectiveness and efficiency of FADNet during the Federated Learning process.

Music-Driven Group Choreography (Part 3)

This is the final part of the series on group dance choreography. In this part, we provide detailed analyses of our proposed group dance generation method.

Experiments

AIOZ-GDANCE Statistics

Figure 1. Distribution (%) of music genres (a) and dance styles (b) in our dataset.

In Figure 1, we show the distribution of music genres and dance styles in our dataset. As illustrated in Figure 1 (Left), Pop and Electronic are popular music genres while other music genres nearly share the same distribution. Meanwhile, on the right of Figure 1, Zumba, Aerobic, and Commercial are the dominant dance styles.

Figure 2. The correlation between dance styles and number of dancers (a); and between dance styles and music genres (b).

Figure 2 (Left) shows the number of dancers in each dance style. Naturally, we see that Zumba, Aerobic, and Commercial have more dancers. On the right of this figure, we illustrate the correlation between music genres and dance styles.

Evaluation Metrics

Similar to prior works on single-dance generation, we evaluate the generated motion quality by calculating the distribution distance between the generated and the ground-truth motions using Fréchet Inception Distance (FID) [1, 2]. To evaluate how well the generated 3D motion correlates to the input music, we use the Motion-Music Consistency metric (MMC) [2, 3]. We also evaluate our model's ability to generate diverse dance motions when given various input music by measuring Generation Diversity (GenDiv) [2, 3].

To evaluate the group dancing quality, we propose three new metrics: Group Motion Realism (GMR), Group Motion Correlation (GMC), and Trajectory Intersection Frequency (TIF). Detailed calculations of these metrics are described as follows:

Group Motion Realism (GMR). To measure the realism of the generated group motion against the ground truth, we need a single unified representation for all dancers' motions in the scene. Based on the kinetic features of a single motion sequence [4], we propose Group Motion Realism (GMR), where smaller is better. Specifically, for each entity $n$, we compute the velocity of each element $j$ of the pose vector: $v^n_{t,j} = \frac{y^n_{t+1,j} - y^n_{t,j}}{\Delta t}$, where $\Delta t$ is the time period between two consecutive frames. Note that the pose vector of each entity at each frame consists of the root orientation, root position and joint angles. The group kinetic feature of a sequence is approximated by taking the logarithm of the total kinetic energy of all group entities as:

$$e_j = \log \left(1 + \frac{1}{T}\frac{1}{N} \sum_{t=1}^T \sum_{n=1}^N m_j (v^n_{t,j})^2\right)$$

where $m_j$ is the moment of inertia or mass of each joint. We assume that $m_j$ is constant with respect to time and entity. Then, we split the sequence into smaller chunks and calculate the features of these chunks. This process is identical for both the generated and ground-truth sequences. Finally, we utilize these sets of features (from generated and ground-truth group dance) to calculate the GMR using the standard FID formulation as in [1].
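A small sketch of the group kinetic feature $e_j$ may help. It assumes poses are stored as nested lists `poses[n][t][j]`, uses a constant `mass`, and averages over the frames for which a finite-difference velocity exists; these choices and the function name are assumptions, and the chunking and FID steps are not shown.

```python
import math


def group_kinetic_feature(poses, j, dt, mass=1.0):
    """e_j = log(1 + mean over frames and entities of m_j * v^2).

    poses -- poses[n][t][j] is element j of the pose of entity n at frame t
    j     -- index of the pose-vector element
    dt    -- time between two consecutive frames
    mass  -- m_j, assumed constant over time and entities
    """
    num_entities = len(poses)
    num_steps = len(poses[0]) - 1  # velocities need two consecutive frames
    if num_steps < 1:
        raise ValueError("at least two frames are required")
    energy = 0.0
    for n in range(num_entities):
        for t in range(num_steps):
            velocity = (poses[n][t + 1][j] - poses[n][t][j]) / dt
            energy += mass * velocity * velocity
    return math.log(1.0 + energy / (num_steps * num_entities))


# Two dancers, three frames, one pose element: one moves, one stands still.
toy_poses = [[[0.0], [1.0], [2.0]],
             [[0.0], [0.0], [0.0]]]
e_0 = group_kinetic_feature(toy_poses, j=0, dt=1.0)
```

In this toy case the mean kinetic energy is 0.5, so the feature equals $\log(1.5)$.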

Group Motion Correlation (GMC). We also evaluate the synchrony and the correlation between dancers within the generated group. We assume that the correlation of movements between individuals is likely to reflect their interaction in the choreography. For every pair of motions within a group, we first align the two motion sequences using the Dynamic Time Warping algorithm based on the Euclidean distance in the joint position space (obtained by the SMPL joint regressor). We then calculate the mean cross-correlations between the time-aligned motion pairs using the kinetic features [4]. The generated group motion correlation degree is then calculated as the average over all motion pairs.

Trajectory Intersection Frequency (TIF). For the generated group sequences, the intersection rate is calculated over all FF frames as:

$$\text{TIF} = \frac{\sum_{F}\sum_{i,j : i\neq j} \mathbb{I}[\text{intersect}(M(y^i),M(y^j))]}{F},$$

where $M$ is the SMPL skinning function [5], which outputs a 6890-vertex human mesh from the input pose parameters $y$. $\text{intersect}(x,y)$ is a function that returns 1 if the two meshes intersect each other and 0 otherwise. For TIF, a smaller value is better and indicates less intersection in the generated group.
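The counting behind this metric can be sketched as follows. Real mesh intersection tests are beyond the scope of a short example, so the sketch takes the `intersect` predicate as an argument and demonstrates it with bounding spheres, which is purely a stand-in assumption; all names are hypothetical.

```python
from itertools import permutations


def trajectory_intersection_frequency(frames, intersect):
    """Average number of intersecting ordered pairs (i != j) per frame.

    frames    -- frames[t] is the list of per-dancer bodies at frame t
    intersect -- predicate returning True when two bodies intersect
    """
    if not frames:
        raise ValueError("at least one frame is required")
    count = 0
    for bodies in frames:
        # The formula sums over ordered pairs, so each unordered pair counts twice.
        for a, b in permutations(bodies, 2):
            count += 1 if intersect(a, b) else 0
    return count / len(frames)


def spheres_intersect(a, b):
    """Stand-in for a mesh test: bodies are (center, radius) bounding spheres."""
    (ca, ra), (cb, rb) = a, b
    squared = sum((x - y) ** 2 for x, y in zip(ca, cb))
    return squared < (ra + rb) ** 2


# Two dancers over two frames: overlapping in frame 0, apart in frame 1.
toy_frames = [
    [((0.0, 0.0, 0.0), 0.5), ((0.4, 0.0, 0.0), 0.5)],
    [((0.0, 0.0, 0.0), 0.5), ((3.0, 0.0, 0.0), 0.5)],
]
tif = trajectory_intersection_frequency(toy_frames, spheres_intersect)
```

Passing the predicate in keeps the counting logic independent of how the meshes are represented.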

Cross-entity Attention Analysis

We compare our method with FACT [2]. FACT is a recent state-of-the-art method designed for single dance generation, which gives our method an advantage. However, it is still the closest competing method, since we propose a new group dance dataset that was not previously available for benchmarking. We also analyse our method with and without Cross-entity Attention. We train all methods with mini-batches containing all dancers within the group instead of sampling each dancer independently as in FACT's original implementation.

Figure 3. Comparison between FACT and our GDanceR. Our method better handles the consistency and cross-body intersection problems between dancers.

Table 1 compares the baseline FACT [2] with our proposed GDanceR, with and without Cross-entity Attention. The results show that GDanceR, especially with Cross-entity Attention, outperforms the baseline by a large margin in all metrics. In Figure 3, we also visualize example outputs of FACT and GDanceR. It is clear that FACT does not handle the intersection problem well. This is understandable as FACT is not designed for group dance generation, while our method with Cross-entity Attention can deal with this problem better.

Table 1. Generation results comparison on AIOZ-GDANCE dataset. w/o CA denotes without using Cross-entity Attention.

Number of Dancers Analysis.

Table 2 shows the generation results of our method when generating different numbers of dancers. In general, the FID, GMR, and GMC metrics do not show much correlation with the number of generated dancers since the results vary. On the other hand, MMC is stable across all setups ($\sim 0.248$), which indicates that our network is robust in generating motion from given music regardless of changes in the initial positions. The generation diversity (GenDiv) decreases while the intersection frequency (TIF) increases when more dancers are generated. These results show that dealing with collisions during the group generation process is worth further investigation.

Table 2. Performance of our proposed method when increasing the number of generated dancers.

Dance Style Analysis

Figure 4. Examples of generated group motions from our method.

Different dance styles exhibit different challenges in group dance generation. As shown in Table 3, Aerobic and Zumba are quite similar for generating choreography as they usually focus on workout and sporty movements. Besides, while Commercial and Irish are easier for the model to reproduce the motions, Bollywood and Samba contain highly skilled movements that are challenging to capture and represent accurately. In Figure 4, we show the generated results of GDanceR with different dance styles. Our Supplementary Material and Demonstration Video also provide more examples.

Table 3. The results of different dance styles. These results are obtained by training the model on each dance style.

Ablation on Latent Motion Fusion.

We investigate different fusion strategies between the local motion $h^i$ and the global-aware motion $g^i$ to obtain the final motion representation $z^i$. Specifically, we experiment with three settings: (i) No Fusion: the final motion is the global-aware motion obtained from our Cross-entity Attention ($z^i = g^i$); (ii) Concatenate: the final motion is the concatenation of the local and global-aware motion ($z^i = [h^i; g^i]$); (iii) Add: the final motion is the sum of the local and global-aware motion ($z^i = h^i + g^i$). Table 4 summarizes the results. We find that fusing the motion by adding the local and global motion features achieves the best results. In this strategy, the global information between entities is effectively encoded into the local motion, so that the final motion retains the comprehensive information of each entity's own past motion as well as the motion of every other entity. With concatenation, in contrast, the model is prone to overfitting due to the redundant information in the local and global representations. No Fusion, on the other hand, can reduce the amount of information about the past motion, leading to insufficient input information, so the Decoder may fail to generate temporally coherent motion aligned with the music.
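The three fusion settings differ only in how the two vectors are combined, which the following sketch makes explicit. The function name `fuse` and the mode strings are assumptions for illustration.

```python
def fuse(h, g, mode="add"):
    """Combine local motion h and global-aware motion g into z."""
    if mode == "none":
        return list(g)            # (i) No Fusion: z = g
    if mode == "concat":
        return list(h) + list(g)  # (ii) Concatenate: z = [h; g]
    if mode == "add":
        if len(h) != len(g):
            raise ValueError("h and g must have the same dimension to be added")
        return [a + b for a, b in zip(h, g)]  # (iii) Add: z = h + g
    raise ValueError(f"unknown fusion mode: {mode}")


h_i, g_i = [1.0, 2.0], [0.5, -1.0]
z_none = fuse(h_i, g_i, "none")
z_concat = fuse(h_i, g_i, "concat")
z_add = fuse(h_i, g_i, "add")
```

Note that only concatenation changes the dimension of $z^i$, so the decoder input size depends on the chosen strategy.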

Table 4. Ablation study on different fusion strategies for the latent motion representation.

Conclusion

In summary, we have introduced AIOZ-GDANCE, the largest dataset for audio-driven group dance generation. Our dataset contains in-the-wild videos and covers different dance styles and music genres. We then propose a strong baseline along with new evaluation metrics for the group dance generation task. We also perform extensive experiments to validate our method on this interesting yet unexplored problem, using our new dataset and evaluation protocols. We hope that the release of our dataset will foster more research on audio-driven group choreography.

References

[1] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. NIPS 2017.

[2] Ruilong Li, Shan Yang, David A Ross, and Angjoo Kanazawa. AI Choreographer: Music-conditioned 3D dance generation with AIST++. ICCV 2021.

[3] Hsin-Ying Lee, Xiaodong Yang, Ming-Yu Liu, Ting-Chun Wang, Yu-Ding Lu, Ming-Hsuan Yang, and Jan Kautz. Dancing to music. NIPS 2019.

[4] Kensuke Onuma, Christos Faloutsos, and Jessica K Hodgins. FMDistance: A fast and effective distance function for motion capture data. Eurographics 2008.

[5] Matthew Loper, Naureen Mahmood, Javier Romero, Gerard Pons-Moll, and Michael J. Black. SMPL: A skinned multi-person linear model. ACM Trans. Graphics, 2015.

Music-Driven Group Choreography (Part 2)

In the previous post, we introduced AIOZ-GDANCE, a new large-scale in-the-wild dataset for music-driven group dance generation. On the basis of the new dataset, we introduce the first strong baseline for group dance generation that can jointly generate multiple dancing motions expressively and coherently.

Figure 1. The overall architecture of GDanceR. Our model takes in a music sequence and a set of initial positions, and then auto-regressively generates coherent group dance motions that are attuned to the input music.

Music-driven Group Dance Generation Method

Problem Formulation

Given an input music audio sequence $\{m_1, m_2, ...,m_T\}$, where $t = \{1,..., T\}$ indexes the music segments, and the initial 3D positions of $N$ dancers $\{\tau^1_0, \tau^2_0, ..., \tau^N_0 \}$, $\tau^i_0 \in \mathbb{R}^{3}$, our goal is to generate the group motion sequences $\{y^1_1,..., y^1_T; ...;y^N_1,...,y^N_T\}$, where $y^i_t$ is the generated pose of the $i$-th dancer at time step $t$. Specifically, we represent the human pose as a 72-dimensional vector $y = [\tau; \theta]$, where $\tau$ and $\theta$ represent the root translation and pose parameters of the SMPL model [1], respectively.

In general, the generated group dance motion should meet two conditions: (i) the generated dancing motion should be consistent with the input music in terms of style, rhythm, and beat; (ii) the motions and trajectories of dancers should be coherent, without cross-body intersection between dancers. To that end, we propose the first baseline method for group dance generation that can jointly generate multiple dancing motions expressively and coherently. Figure 1 shows the architecture of our proposed Music-driven 3D Group Dance generatoR (GDanceR), which consists of three main components:

  • Transformer Music Encoder.
  • Initial Pose Generator.
  • Group Motion Generator.

Transformer Music Encoder

From the raw audio signal of the input music, we first extract music features using the available audio processing library Librosa. Concretely, we extract the mel frequency cepstral coefficients (MFCC), MFCC delta, constant-Q chromagram, tempogram, onset strength and one-hot beat, which results in a 438-dimensional feature vector. We then encode the music sequence $M =\{m_1, m_2, ...,m_T\}$, $m_t \in \mathbb{R}^{438}$, into a sequence of hidden representations $\{a_1, a_2,..., a_T\}$, $a_t \in \mathbb{R}^{d_a}$. In practice, we utilize the self-attention mechanism of the transformer [2] to effectively encode the multi-scale information and the long-term dependency between music frames. The hidden audio at each time step is expected to contain meaningful structural information to ensure that the generated dancing motion is coherent across the whole sequence.

Specifically, we first embed the music features $m_t$ using a Linear layer followed by Positional Encoding to encode the time ordering of the sequence:

$$U = \text{PE}(M W^u_a)$$

where $\text{PE}$ denotes the Positional Encoding, and $W^u_a \in \mathbb{R}^{438 \times d_a}$ contains the parameters of the linear projection layer. Then, the hidden audio information can be calculated using the self-attention mechanism:

$$\mathbb{A} = \text{FF}\left(\text{softmax}\left(\frac{U^q (U^k)^\top}{\sqrt{d_{k}}} \right) U^v \right), \quad U^q = U W^q_a, \quad U^k = U W^k_a, \quad U^v = UW^v_a$$

where $W^q_a, W^k_a \in \mathbb{R}^{d_a \times d_k}$ and $W^v_a \in \mathbb{R}^{d_a \times d_v}$ are the parameters that transform the linearly embedded audio $U$ into a query $U^q$, a key $U^k$, and a value $U^v$, respectively. $d_a$ is the dimension of the hidden audio representation, $d_k$ is the dimension of the query and key, and $d_v$ is the dimension of the value. $\text{FF}$ is a feed-forward neural network.
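For readers who want to trace the attention computation by hand, the following is a minimal plain-Python sketch of scaled dot-product self-attention on a short sequence. It deliberately omits the Positional Encoding and the feed-forward network, uses nested lists instead of tensors, and the helper names are assumptions.

```python
import math


def matmul(X, W):
    """Multiply a (rows x inner) matrix by an (inner x cols) matrix."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*W)] for row in X]


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(U, Wq, Wk, Wv):
    """softmax(U^q (U^k)^T / sqrt(d_k)) U^v for an embedded sequence U."""
    Q, K, V = matmul(U, Wq), matmul(U, Wk), matmul(U, Wv)
    d_k = len(K[0])
    K_transposed = [list(col) for col in zip(*K)]
    scores = matmul(Q, K_transposed)
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


# Two time steps, d_a = d_k = d_v = 2, identity projections for readability.
identity = [[1.0, 0.0], [0.0, 1.0]]
U = [[1.0, 0.0], [0.0, 1.0]]
hidden = self_attention(U, identity, identity, identity)
```

With identity projections each output row is just the attention weights of that time step, so it sums to one and puts the most weight on the step itself.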

Initial Pose Generator

Figure 2. The Transformer Music Encoder encodes the acoustic and rhythmic information to generate the initial poses from the input positions.

Given the initial positions of all dancers, we generate the initial poses by combining the audio feature with the starting positions. We aggregate the audio representation by taking an average over the audio sequence. The aggregated audio is then concatenated with the input position and fed to a multilayer perceptron (MLP) to predict the initial pose for each dancer:

$$y^i_0 = \text{MLP}\left( \left[\frac{1}{T}\sum_{t=1}^T a_t ; \tau^i_0 \right] \right),$$

where $[;]$ is the concatenation operator and $\tau^i_0$ is the initial position of the $i$-th dancer.

Group Motion Generator

Figure 3. The Group Motion Generator auto-regressively generates coherent group dance motions based on the encoded acoustic information.

To generate the group dance motion, we aim to synthesize the coherent motion of each dancer such that it aligns well with the input music. Furthermore, we also need to maintain global consistency between all dancers. As shown in Figure 3, our Group Generator comprises a Group Encoder to encode the group sequence information and an MLP Decoder to decode the hidden representation back to the human pose space. To effectively extract both the local motion and global information of the group dance sequence through time, we design our Group Encoder based on two factors: Recurrent Neural Network [3] to capture the temporal motion dynamics of each dancer, and Attention mechanism [2] to encode the spatial relationship of all dancers.

Specifically, at each time step, the pose of each dancer in the previous frame $y^i_{t-1}$ is sent to an LSTM unit to encode the hidden local motion representation $h^i_t$:

$$h^i_t = \text{LSTM}(y^i_{t-1}, h^i_{t-1})$$

To ensure the motions of all dancers have global coherency and discourage strange effects such as cross-body intersection, we introduce the Cross-entity Attention mechanism. In particular, each individual motion representation is first linearly projected into a key vector $k^i$, a query vector $q^i$ and a value vector $v^i$ as follows:

$$k^i = h^i W^{k}, \quad q^i = h^i W^{q}, \quad v^i = h^i W^{v},$$

where $W^q, W^k \in \mathbb{R}^{d_h \times d_k}$ and $W^v \in \mathbb{R}^{d_h \times d_v}$ are parameters that transform the hidden motion $h$ into a query, a key, and a value, respectively. $d_k$ is the dimension of the query and key while $d_v$ is the dimension of the value vector. To encode the relationship between dancers in the scene, our Cross-entity Attention also utilizes the Scaled Dot-Product Attention as in the Transformer [2].

Figure 4. The Group Encoder learns to encode the relations among dancers through our proposed Cross-entity Attention mechanism.

In practice, we find that people who are positioned closer to each other tend to have higher correlation in their movement. Therefore, we adopt a Spatial Encoding strategy to encode the spatial relationship between each pair of dancers. The Spatial Encoding between two entities, based on their distance in 3D space, is defined as follows:

$$e_{ij} = \exp\left(-\frac{\Vert \tau^i - \tau^j \Vert^2}{\sqrt{d_{\tau}}}\right),$$

where $d_{\tau}$ is the dimension of the position vector $\tau$. Considering the query $q^i$, which represents the current entity information, and the key $k^j$, which represents other entity information, we inject the spatial relation information between these two entities into their cross attention coefficient:

$$\alpha_{ij} = \text{softmax}\left(\frac{(q^i)^\top k^j}{\sqrt{d_k}} + e_{ij}\right).$$

To preserve the spatial relative information in the attentive representation, we also embed it into the hidden value vector and obtain the global-aware representation $g^i$ of the $i$-th entity as follows:

$$g^i = \sum_{j=1}^N\alpha_{ij}(v^j + e_{ij}\gamma),$$

where $\gamma \in \mathbb{R}^{d_v}$ is a learnable bias scaled by the Spatial Encoding. Intuitively, the Spatial Encoding acts as a bias in the attention weight, encouraging the interactivity and awareness to be higher between closer entities. Our attention mechanism can adaptively attend to each dancer and the others in both a temporal and a spatial manner, thanks to the encoded motion as well as the spatial information.
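The following sketch puts the Spatial Encoding, the attention coefficients $\alpha_{ij}$, and the global-aware representation $g^i$ together in plain Python. It assumes the query, key, and value vectors have already been projected, applies the softmax over $j$ for each $i$, and uses hypothetical names throughout; it is a reading aid rather than the model's implementation.

```python
import math


def spatial_encoding(tau_i, tau_j):
    """e_ij = exp(-||tau_i - tau_j||^2 / sqrt(d_tau))."""
    squared = sum((a - b) ** 2 for a, b in zip(tau_i, tau_j))
    return math.exp(-squared / math.sqrt(len(tau_i)))


def cross_entity_attention(q, k, v, tau, gamma):
    """Return (g, alpha): global-aware features and attention weights."""
    n, d_k = len(q), len(k[0])
    g_all, alpha_all = [], []
    for i in range(n):
        e = [spatial_encoding(tau[i], tau[j]) for j in range(n)]
        logits = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_k) + e[j]
                  for j in range(n)]
        peak = max(logits)
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        alpha = [x / total for x in exps]
        # g_i = sum_j alpha_ij * (v_j + e_ij * gamma)
        g = [sum(alpha[j] * (v[j][c] + e[j] * gamma[c]) for j in range(n))
             for c in range(len(gamma))]
        g_all.append(g)
        alpha_all.append(alpha)
    return g_all, alpha_all


# Three dancers with zero queries/keys, so only the positions shape the weights.
positions = [(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (10.0, 0.0, 0.0)]
zeros = [[0.0, 0.0] for _ in positions]
values = [[1.0], [2.0], [3.0]]
g, alpha = cross_entity_attention(zeros, zeros, values, positions, gamma=[0.0])
```

In this toy setup dancer 0 attends more strongly to the nearby dancer 1 than to the distant dancer 2, which is the behaviour the Spatial Encoding is meant to encourage.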

We then fuse the local and global motion representations by adding $h^i$ and $g^i$ to obtain the final latent motion $z^i$. Our final global-local representation of each entity is expected to carry the comprehensive information of its own past motion as well as the motion of every other entity, enabling the MLP Decoder to generate coherent group dancing sequences. Finally, we generate the next movement $y^i_t$ based on the final motion representation $z^i_t$ as well as the hidden audio representation $a_t$, and thus can capture the fine-grained correspondence between the music feature sequence and the dance movement sequence:

$$y^i_t = \text{MLP}([z^i_t; a_t]).$$

Built upon these components, our model can effectively learn and generate coherent group dance animation given several pieces of music. In the next part, we will go through the experiments and detailed studies of the method.

References

[1] Matthew Loper, Naureen Mahmood, Javier Romero, Gerard Pons-Moll, and Michael J. Black. SMPL: A skinned multi-person linear model. ACM Trans. Graphics, 2015.

[2] Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I. Attention is all you need. NIPS 2017.

[3] Hochreiter S, Schmidhuber J. Long short-term memory. Neural computation. 1997 Nov 15;9(8):1735-80.

Music-Driven Group Choreography (Part 1)

Dancing is an important part of human culture and remains one of the most expressive physical art and communication forms. With the rapid development of digital social media platforms, creating dancing videos has gained significant attention from social communities. As a result, millions of dancing videos are created and watched daily on online platforms. Recently, studies of how to create natural dancing motion from music have attracted great attention in the research community.

Figure 1. We demonstrate the AIOZ-GDANCE dataset with in-the-wild videos, music audio, and 3D group dance motion.

Nevertheless, generating dance motion for a group of dancers remains an open problem and has not yet been well investigated by the community. Motivated by these shortcomings and to foster research on group choreography, we establish AIOZ-GDANCE, a new large-scale in-the-wild dataset for music-driven group dance generation. Unlike existing datasets that only support single dance, our new dataset contains group dance videos as shown in Figure 1, hence supporting the study of group choreography. On the basis of the new dataset, we propose the first strong baseline for group dance generation that can jointly generate multiple dancing motions expressively and coherently.

Dataset Construction

Figure 2. The pipeline of making our AIOZ-GDANCE dataset.

In this section, we describe the process of building our dataset from a large variety of videos available on the internet. Because our main goal is to develop a large-scale dataset with in-the-wild videos, setting up a MoCap system as in many classical approaches is not feasible. However, manually creating 3D ground truth for millions of frames from dancing videos is also an extremely costly and tedious job. To that end, we propose a semi-automatic labeling method with humans in the loop to obtain the 3D ground truth for our dataset. The process of constructing the data includes the following five key steps:

  1. Video collection
  2. Human Tracking
  3. Human Pose Estimation
  4. Local Fitting for Individual Motions
  5. Global Scene Optimization

Data Collection and Preprocessing

Figure 3. Human Tracking.

Video Collection. We collect in-the-wild, public domain group dancing videos along with their music from YouTube, TikTok, and Facebook. All group dance videos are processed at 1920 × 1080 resolution and 30 FPS.

Human Tracking. We perform tracking for all humans in the videos using a state-of-the-art multi-object tracker [1] to obtain the tracking bounding boxes. Note that although the tracker can produce reasonable results, there are failure cases in some frames. Therefore, we manually correct the bounding boxes of the incorrect cases. This tracking correction is crucial since we want the trajectory of each person to be accurately tracked in order to reconstruct their motion in later stages.

Pose Estimation. Given the bounding boxes of each person in the video, we leverage a state-of-the-art 2D pose estimation method [2] to generate the initial 2D poses for each person. In practice, there exist some inaccurately detected keypoints due to motion blur and partial occlusion. We manually fix the incorrect cases to obtain the 2D keypoints of each human bounding box.

Local Mesh Fitting

Figure 4. Local Mesh Fitting.

To construct 3D group dance motion, we first reconstruct the full body motion for each dancer by fitting the 3D mesh. We then jointly optimize all dancer motions to construct the globally-coherent group motion. Finally, we post-process and remove wrong cases from the optimization results.

We use the SMPL model [3] to represent the 3D human. The SMPL model is a differentiable function that maps the pose parameters $\mathbf{\theta}$, the shape parameters $\mathbf{\beta}$, and the root translation $\mathbf{\tau}$ into a set of 3D human body mesh vertices $\mathbf{V}\in \mathbb{R}^{6890\times3}$ and 3D joints $\mathbf{X}\in \mathbb{R}^{J\times3}$, where $J$ is the number of body joints.

Our optimizing motion variables for each individual dancer consist of a sequence of SMPL joint angles {θt}t=1T\{\mathbf{\theta}_t\}_{t=1}^T, a sequence of the root translation {τt}t=1T\{\mathbf{\tau}_t\}_{t=1}^T, and a single SMPL shape parameter β\mathbf{\beta}. We fit the sequence of SMPL motion variables to the tracked 2D keypoints by extending SMPLify-X framework [4] across the whole video sequence:

Elocal=EJ+λθEθ+λβEβ+λSES+λFEFE_{\rm local} = E_{\rm J} + \lambda_{\theta}E_{\theta} + \lambda_{\beta} E_{\beta} + \lambda_{\rm S}E_{\rm S} + \lambda_{\rm F}E_{\rm F}

where:

  • $E_{\rm J}$ is the 2D reprojection term between the 2D keypoints and the 2D projection of the corresponding 3D poses.
  • $E_{\theta}$ is the pose prior term from the latent space of the VPoser model [4] to encourage plausible human poses.
  • $E_{\beta}$ is the shape prior term that regularizes the body shape towards the mean shape of the SMPL body model.
  • $E_{\rm S} = \sum_{t=1}^{T-1}\Vert \mathbf{\theta}_{t+1} - \mathbf{\theta}_{t} \Vert^2 + \sum_{j=1}^J\sum_{t=1}^{T-1}\Vert \mathbf{X}_{j,t+1} - \mathbf{X}_{j,t} \Vert^2$ is the smoothness term that encourages temporal smoothness of the motion.
  • $E_{\rm F} = \sum_{t=1}^{T-1} \sum_{j \in \mathcal{F}} c_{j,t}\Vert \mathbf{X}_{j,t+1} - \mathbf{X}_{j,t} \Vert^2$ keeps the feet joints stationary (zero velocity) while they are in contact. Here $\mathcal{F}$ is the set of feet joint indexes and $c_{j,t}$ is the contact label of foot joint $j$ at time $t$.
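As a rough illustration of the two temporal terms above, the following sketch evaluates $E_{\rm S}$ and $E_{\rm F}$ on a toy sequence. It is not the authors' implementation: the function names, the nested-list data layout, and the toy values are assumptions made only for this example, and the pose vectors are treated as plain lists of numbers.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def smoothness_term(thetas, joints):
    """E_S: squared frame-to-frame change of pose parameters and 3D joints.

    thetas: list of T pose vectors.
    joints: list of T frames, each a list of J 3D points.
    """
    pose = sum(sq_dist(thetas[t + 1], thetas[t]) for t in range(len(thetas) - 1))
    pos = sum(
        sq_dist(joints[t + 1][j], joints[t][j])
        for t in range(len(joints) - 1)
        for j in range(len(joints[t]))
    )
    return pose + pos


def foot_contact_term(joints, contacts, foot_ids):
    """E_F: penalize foot motion between t and t+1 when contacts[t][j] is 1."""
    return sum(
        contacts[t][j] * sq_dist(joints[t + 1][j], joints[t][j])
        for t in range(len(joints) - 1)
        for j in foot_ids
    )


# Toy sequence (hypothetical values): T = 3 frames, J = 2 joints,
# where joint 1 plays the role of a foot joint.
thetas = [[0.0, 0.0], [0.1, 0.0], [0.1, 0.2]]
joints = [
    [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.1, 1.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.2, 1.0, 0.0], [0.3, 0.0, 0.0]],
]
contacts = [[0, 1], [0, 1], [0, 0]]  # contacts[t][j]

e_s = smoothness_term(thetas, joints)
e_f = foot_contact_term(joints, contacts, foot_ids=[1])
```

In this toy case the foot slides between the last two frames while it is still labelled as in contact, so only that step contributes to `e_f`.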

Global Optimization

Figure 5. Global Optimization.

Given the 3D motion sequence of each dancer $p$: $\{\mathbf{\theta}^p_t, \mathbf{\tau}^p_t\}$, we further resolve the motion trajectory problems in group dance by solving the following objective:

$$E_{\rm global} = E_{\rm J} + \lambda_{\rm pen}E_{\rm pen} + \lambda_{\rm reg}\sum_{p}E_{\rm reg}(p) + \lambda_{\rm dep}\sum_{p,p',t}E_{\rm dep}(p,p',t) + \lambda_{\rm gc}\sum_{p}E_{\rm gc}(p)$$

$E_{\rm pen}$ is the Signed Distance Function penetration term, which prevents the reconstructed motions of different dancers from overlapping.

$E_{\rm reg}(p) =\sum_{t=1}^T\Vert \mathbf{\theta}^p_t - \hat{\mathbf{\theta}}^p_t\Vert^2$ is the regularization term that prevents the motion from deviating too much from the previously optimized individual motion $\{\hat{\mathbf{\theta}}^p_t\}$ obtained by optimizing the local mesh for dancer $p$.

In practice, we find that the relative depth ordering of dancers in the scene can be inconsistent due to the ambiguity of the 2D projection. To ensure the group motion quality, we watch the videos and manually provide the ordinal depth relation of all dancers in the scene at each frame $t$ as follows:

$$r_t(p,p') = \begin{cases} 1, &\text{if dancer } p \text{ is closer than } p' \\ -1, &\text{if dancer } p \text{ is farther than } p' \\ 0, &\text{if their depths are roughly equal} \end{cases}$$

Given the relative depth information, we derive the depth relation term $E_{\rm dep}$. This term encourages a consistent ordinal depth relation between the motion trajectories of multiple dancers, especially when dancers partially occlude each other:

$$E_{\rm dep}(p,p',t) = \begin{cases} \log(1+\exp(z^p_t - z^{p'}_t)), &r_t(p,p')=1 \\ \log(1+\exp(-z^p_t + z^{p'}_t)), &r_t(p,p')=-1 \\ (z^p_t - z^{p'}_t)^2, &r_t(p,p')=0 \end{cases}$$

where $z^p_t$ is the depth component of the root translation $\mathbf{\tau}^p_t$ of person $p$ at frame $t$. Intuitively, for $r_t(p,p')=1$, $z^p_t$ should be smaller than $z^{p'}_t$, and vice versa.
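The piecewise depth term can be sketched in a few lines. This is only an illustration under stated assumptions: `depth_term` and `softplus` are hypothetical helper names, depths are plain floats, and the numerically stable form of $\log(1+\exp(x))$ is a choice made for this example.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def depth_term(z_p, z_q, r):
    """E_dep for one pair of dancers (p, p') at one frame.

    z_p, z_q: depth components of the two root translations.
    r: annotated ordinal relation, 1 (p closer), -1 (p farther), 0 (equal).
    """
    if r == 1:
        return softplus(z_p - z_q)   # small when z_p < z_q
    if r == -1:
        return softplus(z_q - z_p)   # small when z_p > z_q
    return (z_p - z_q) ** 2          # pulls the two depths together


# Annotation says p is closer than p' (r = 1).
consistent = depth_term(2.0, 5.0, 1)     # reconstruction agrees with the label
inconsistent = depth_term(5.0, 2.0, 1)   # reconstruction contradicts the label
```

The penalty is small when the reconstructed depths agree with the annotated ordering and grows roughly linearly with the depth gap when they contradict it.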

Finally, we apply the global ground contact constraint $E_{\rm gc}$ to further ensure consistency between the motion of every person and the environment based on the ground contact information. This contact term is also needed to reduce artifacts such as foot-skating, jittering, and penetration under the ground.

$$E_{\rm gc}(p) = \sum_{t=1}^{T-1} \sum_{j \in \mathcal{F}} c^p_{j,t}\Vert \mathbf{X}^p_{j,t+1} - \mathbf{X}^p_{j,t} \Vert^2 + c^p_{j,t} \Vert (\mathbf{X}^p_{j,t} - \mathbf{f})^\top \mathbf{n}^* \Vert^2$$

where $\mathcal{F}$ is the set of feet joint indexes, $\mathbf{n}^*$ is the estimated plane normal, and $\mathbf{f}$ is a fixed 3D point on the ground plane. The first term in the equation above is the zero-velocity constraint when the feet are in contact with the ground, while the second term encourages the feet positions to stay near the ground when in contact. To obtain the ground plane parameters, we initialize the plane point $\mathbf{f}$ as the weighted median of all contact feet positions. The plane normal $\mathbf{n}^*$ is obtained by optimizing a robust Huber objective:

$$\mathbf{n}^* = \arg\min_{\mathbf{n}} \sum_{\mathbf{X}_{\rm feet}} \mathcal{H}\left((\mathbf{X}_{\rm feet} - \mathbf{f})^\top \frac{\mathbf{n}}{\Vert\mathbf{n}\Vert}\right) + \Vert \mathbf{n}^\top\mathbf{n} - 1 \Vert^2,$$

where $\mathcal{H}$ is the Huber loss function and $\mathbf{X}_{\rm feet}$ are the 3D feet positions of all dancers across the whole sequence that are labelled as in contact (i.e., $c^p_{j,t} = 1$).
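To make the plane-normal objective concrete, the sketch below only evaluates it for two candidate normals; it does not perform the optimization itself. The Huber threshold `delta`, the function names, and the toy foot positions are assumptions for illustration, since the post does not specify them.

```python
import math


def huber(r, delta=1.0):
    """Huber loss: quadratic near zero, linear for |r| > delta."""
    a = abs(r)
    return 0.5 * r * r if a <= delta else delta * (a - 0.5 * delta)


def plane_objective(normal, feet, plane_point, delta=1.0):
    """Huber-robust point-to-plane cost plus the unit-norm penalty on n."""
    norm = math.sqrt(sum(n * n for n in normal))
    unit = [n / norm for n in normal]
    data = sum(
        huber(sum((x - f) * u for x, f, u in zip(p, plane_point, unit)), delta)
        for p in feet
    )
    reg = (norm * norm - 1.0) ** 2   # (n^T n - 1)^2
    return data + reg


# Hypothetical contact-foot positions lying close to the plane y = 0.
feet = [[0.0, 0.0, 0.0], [1.0, 0.01, 0.0], [0.0, -0.01, 1.0], [1.0, 0.0, 1.0]]
f = [0.0, 0.0, 0.0]

cost_up = plane_objective([0.0, 1.0, 0.0], feet, f)       # correct normal
cost_side = plane_objective([1.0, 0.0, 0.0], feet, f)     # wrong normal
cost_unnormalized = plane_objective([0.0, 2.0, 0.0], feet, f)
```

A candidate that matches the true ground orientation gets a much lower cost, and a correctly oriented but non-unit normal is penalized only through the regularizer.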

How will AIOZ-GDANCE be useful to the community?

We highlight some interesting research directions that can benefit from our dataset:

  • Group Dance Generation
  • Human Pose Tracking
  • Dance Education
  • Dance style transfer
  • Human behavior analysis

While single-person choreography has recently become a hot research topic, group dance generation has not yet been well investigated. We hope that the release of our dataset will foster more research in this direction.

References

[1] Peize Sun, Jinkun Cao, Yi Jiang, Zehuan Yuan, Song Bai, Kris Kitani, and Ping Luo. DanceTrack: Multi-object tracking in uniform appearance and diverse motion. In CVPR, 2022.

[2] Hao-Shu Fang, Shuqin Xie, Yu-Wing Tai, and Cewu Lu. RMPE: Regional multi-person pose estimation. In ICCV, 2017.

[3] Matthew Loper, Naureen Mahmood, Javier Romero, Gerard Pons-Moll, and Michael J. Black. SMPL: A skinned multi-person linear model. ACM Trans. Graphics, 2015.

[4] Georgios Pavlakos, Vasileios Choutas, Nima Ghorbani, Timo Bolkart, Ahmed A. A. Osman, Dimitrios Tzionas, and Michael J. Black. Expressive body capture: 3D hands, face, and body from a single image. In CVPR, 2019.

Uncertainty-aware Label Distribution Learning for Facial Expression Recognition (Part 1)

Facial expression recognition (FER) plays an important role in understanding people's feelings and interactions between humans. Recently, automatic emotion recognition has gained a lot of attention from the research community due to its many applications in education, healthcare, human analysis, surveillance, and human-robot interaction. Recent FER methods are mostly based on deep learning and can achieve impressive results. The success of deep models can be attributed to large-scale FER datasets [1][2]. However, the ambiguity of facial expressions is still a key challenge in FER. Specifically, people with different backgrounds might perceive and interpret facial expressions differently, which can lead to noisy and inconsistent annotations. In addition, real-life facial expressions usually manifest a mixture of feelings rather than only a single emotion.

Motivation and Proposed Solution

Figure 1. Examples of real-world ambiguous facial expressions that can lead to noisy and inconsistent annotation.

As an example, Figure 1 shows that people may have different opinions about the expressed emotion, particularly in ambiguous images. Consequently, a distribution over emotion categories is better than a single label because it takes all sentiment classes into account and can cover various interpretations, thus mitigating the effect of ambiguity. However, existing large-scale FER datasets only provide a single label for each sample instead of a label distribution, which means we do not have a comprehensive description for each facial expression. This can lead to insufficient supervision during training and pose a big challenge for many FER systems.

To overcome the ambiguity problem in FER, we propose a new uncertainty-aware label distribution learning method that constructs emotion distributions for training samples. Specifically, we leverage the neighborhood information of samples that have similar expressions to construct the emotion distributions from single labels and use them as the training supervision signal.

Methodology

Preliminaries

We denote $\mathbf{x} \in \mathcal{X}$ as the instance variable in the input space $\mathcal{X}$ and $\mathbf{x}^{i}$ as the particular $i$-th instance. The label set is denoted as $\mathcal{Y} = \{y_1, y_2,..., y_m\}$ where $m$ is the number of classes and $y_j$ is the label value of the $j$-th class. The logical label vector of $\mathbf{x}^{i}$ is indicated by $\mathbf{l}^{i} = (l^{i}_{y_1}, l^{i}_{y_2}, ..., l^{i}_{y_m})$ with $l^{i}_{y_j} \in \{0, 1\}$ and $\| \mathbf{l} \|_1 = 1$. We define the label distribution of $\mathbf{x}^{i}$ as $\mathbf{d}^{i} = (d^{i}_{y_1}, d^{i}_{y_2}, ..., d^{i}_{y_m})$ with $\| \mathbf{d} \|_1 = 1$ and $d^{i}_{y_j} \in [0, 1]$ representing the relative degree to which $\mathbf{x}^{i}$ belongs to the class $y_j$.

Most existing FER datasets assign only a single class or, equivalently, a logical label $\mathbf{l}^{i}$ to each training sample $\mathbf{x}^{i}$. In particular, the given training dataset is a collection of $n$ samples with logical labels $D_l = \{ (\mathbf{x}^{i}, \mathbf{l}^{i}) \vert 1 \le i \le n\}$. However, we find that a label distribution $\mathbf{d}^i$ is a more comprehensive and suitable annotation for the image than a single label.

Inspired by the recent success of label distribution learning (LDL) in addressing label ambiguity [3], we aim to construct an emotion distribution $\mathbf{d}^i$ for each training sample $\mathbf{x}^i$, thus transforming the training set $D_l$ into $D_d = \{ (\mathbf{x}^{i}, \mathbf{d}^{i}) \vert 1 \le i \le n\}$, which can provide richer supervision information and help mitigate the ambiguity issue. We use cross-entropy to measure the discrepancy between the model's prediction and the constructed target distribution. Hence, the model can be trained by minimizing the following classification loss:

$$\mathcal{L}_{cls} = \sum_{i=1}^n \text{CE}\left(\mathbf{d}^i, f(\mathbf{x}^i; \theta)\right) = -\sum_{i=1}^n \sum_{j=1}^m \mathbf{d}_j^{i} \log f_j(\mathbf{x}^{i};\theta),$$

where $f(\mathbf{x}; \theta)$ is a neural network with parameters $\theta$ followed by a softmax layer to map the input image $\mathbf{x}$ into an emotion distribution.
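The classification loss is an ordinary cross-entropy in which the target is a full distribution instead of a one-hot vector. A minimal sketch for a single sample follows; the logits, the three-class setup, and the small `eps` guard inside the logarithm are assumptions made for this example.

```python
import math


def softmax(logits):
    """Map raw network outputs to a probability distribution."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(target, probs, eps=1e-12):
    """CE(d, f(x)) = -sum_j d_j * log f_j(x) for one sample."""
    return -sum(d * math.log(p + eps) for d, p in zip(target, probs))


logits = [2.0, 1.0, 0.1]           # hypothetical network outputs, m = 3 classes
probs = softmax(logits)

logical_label = [1.0, 0.0, 0.0]    # one-hot annotation l
label_dist = [0.7, 0.2, 0.1]       # constructed distribution d

loss_logical = soft_cross_entropy(logical_label, probs)
loss_dist = soft_cross_entropy(label_dist, probs)
```

With a one-hot target the expression reduces to the usual negative log-likelihood of the annotated class, so the logical-label case is a special case of the same loss.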

Overview

Figure 2. An overview of our Label Distribution Learning with Valence-Arousal (LDLVA) for facial expression recognition under ambiguity.

An overview of our method is presented in Figure 2. To construct the label distribution for each training instance $\mathbf{x}^i$, we leverage its neighborhood information in the valence-arousal space. Particularly, we identify $K$ neighbor instances for each training sample $\mathbf{x}^i$ and utilize our adaptive similarity mechanism to determine their contribution degrees to the target distribution $\mathbf{d}^i$. Then, we combine the neighbors' predictions and their corresponding contribution degrees with the provided label $\mathbf{l}^i$ and $\mathbf{l}^i$'s uncertainty factor to obtain the label distribution $\mathbf{d}^i$. The constructed distribution $\mathbf{d}^i$ will be used as supervision information to train the model via label distribution learning.

Adaptive Similarity

We assume that the label distribution of the main instance $\mathbf{x}^i$ can be computed as a linear combination of its neighbors' distributions. To determine the contribution of each neighbor, we propose an adaptive similarity mechanism that not only leverages the relationships between $\mathbf{x}^i$ and its neighbors in the auxiliary space but also utilizes their feature vectors extracted from the backbone. We choose the valence-arousal space [4] as the auxiliary space to construct the target label distribution. We use the $K$-Nearest Neighbor algorithm to identify the $K$ closest points for each training sample $\mathbf{x}^i$, denoted as $N(i)$. We calculate the adaptive contribution degrees of neighbor instances as the product of the local similarity $s^i_k$ and the calibration score $\zeta^i_k$ as follows:

$$c^i_k = \begin{cases} \zeta^i_k s^i_k, &\text{for } \mathbf{x}^k \in N(i), \\ 0, &\text{otherwise}. \end{cases}$$

where the local similarity $s^i_k$ is defined based on the distance between the instance and its neighbor in the valence-arousal space, with coordinates $\mathbf{a}^i$ and $\mathbf{a}^k$:

$$s^i_k = \exp\left(-\frac{\| \mathbf{a}^i - \mathbf{a}^k \|^2_2}{\delta^2}\right), \quad \forall \mathbf{x}^k \in N(i)$$

We utilize a multilayer perceptron (MLP) $g$ with parameters $\phi$ to calculate the adaptive calibration score from the extracted features of the two instances, $\mathbf{v}^i$ and $\mathbf{v}^k$, obtained from the backbone:

$$\zeta^i_k = \mathrm{Sigmoid}\left(g([\mathbf{v}^i,\mathbf{v}^{k}];\phi)\right)$$

The proposed adaptive similarity can correct similarity errors in the valence-arousal space. Such errors arise because valence-arousal values are not always available in practice, so we rely on an existing method to generate pseudo valence-arousal values.
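The contribution degrees can be sketched as follows. Several things here are assumptions for illustration only: the MLP $g$ is replaced by a list of precomputed raw scores (`raw_scores[k]` stands in for $g([\mathbf{v}^i,\mathbf{v}^k];\phi)$ for a fixed $i$), the instance itself is excluded from its own neighborhood, and the valence-arousal coordinates and $\delta$ are toy values.

```python
import math


def sq_dist(a, b):
    """Squared Euclidean distance in the valence-arousal plane."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def contribution_degrees(i, va, raw_scores, k, delta):
    """Return {neighbor index: c_k^i} for the k nearest neighbors of sample i.

    va: list of valence-arousal points a^k.
    raw_scores: stand-in for the MLP outputs g([v^i, v^k]; phi).
    Samples outside N(i) implicitly get a contribution of 0.
    """
    others = [j for j in range(len(va)) if j != i]
    neighbors = sorted(others, key=lambda j: sq_dist(va[i], va[j]))[:k]
    degrees = {}
    for j in neighbors:
        s = math.exp(-sq_dist(va[i], va[j]) / delta ** 2)  # local similarity
        zeta = sigmoid(raw_scores[j])                      # calibration score
        degrees[j] = zeta * s
    return degrees


va = [[0.5, 0.5], [0.6, 0.5], [0.5, 0.4], [-0.8, -0.7]]
raw_scores = [0.0, 0.0, 2.0, 0.0]
c = contribution_degrees(0, va, raw_scores, k=2, delta=0.5)
```

Samples 1 and 2 are equally close to sample 0 in the valence-arousal plane, so their local similarities match; the calibration score is what separates their final contributions.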

Uncertainty-aware Label Distribution Construction

After obtaining the contribution degree of each neighbor $\mathbf{x}^k \in N(i)$, we can now generate the target label distribution $\mathbf{d}^i$ for the main instance $\mathbf{x}^i$. The target label distribution is calculated using the logical label $\mathbf{l}^i$ and the aggregated distribution $\tilde{\mathbf{d}}^i$ defined as follows:

$$\tilde{\mathbf{d}}^i = \frac{\sum_k c^i_k f(\mathbf{x}^{k};\theta)}{\sum_k c^i_k}, \qquad \mathbf{d}^i = (1-\lambda^i) \mathbf{l}^i + \lambda^i \tilde{\mathbf{d}}^i$$

where $\lambda^i \in [0,1]$ is the uncertainty factor for the logical label. It controls the balance between the provided label $\mathbf{l}^i$ and the aggregated distribution $\tilde{\mathbf{d}}^i$ from the local neighborhood.

Intuitively, a high value of $\lambda^i$ indicates that the logical label is highly uncertain, which can be caused by an ambiguous expression or a low-quality input image, so we should put more weight on the neighborhood information $\tilde{\mathbf{d}}^i$. Conversely, when $\lambda^i$ is small, the label distribution $\mathbf{d}^i$ should be close to $\mathbf{l}^i$ since we are certain about the provided manual label. In our implementation, $\lambda^i$ is a trainable parameter for each instance and is optimized jointly with the model's parameters using gradient descent.
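A small worked example of the construction, under the assumption that the neighbor predictions and contribution degrees are already available (the numbers below are made up, and $\lambda^i$ is fixed by hand here rather than learned):

```python
def aggregate(neighbor_preds, weights):
    """Weighted average of the neighbors' predicted distributions."""
    total = sum(weights)
    m = len(neighbor_preds[0])
    return [
        sum(w * p[j] for w, p in zip(weights, neighbor_preds)) / total
        for j in range(m)
    ]


def build_target(logical, aggregated, lam):
    """d = (1 - lambda) * l + lambda * d_tilde."""
    return [(1.0 - lam) * l + lam * d for l, d in zip(logical, aggregated)]


neighbor_preds = [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]]   # f(x^k; theta)
weights = [0.75, 0.25]                                # c_k^i
logical = [1.0, 0.0, 0.0]                             # l^i

d_tilde = aggregate(neighbor_preds, weights)          # [0.5, 0.4, 0.1]
d_certain = build_target(logical, d_tilde, lam=0.0)   # equals the logical label
d_uncertain = build_target(logical, d_tilde, lam=0.4) # [0.8, 0.16, 0.04]
```

Because both $\mathbf{l}^i$ and $\tilde{\mathbf{d}}^i$ sum to one, any convex combination of them is again a valid distribution.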

Loss Function

To enhance the model's ability to discriminate between ambiguous emotions, we also propose a discriminative loss to reduce the intra-class variations of the learned facial representations. We incorporate the label uncertainty factor $\lambda^i$ to adaptively penalize the distance between the sample and its corresponding class center. For instances with high uncertainty, this penalty is down-weighted, so the network can tolerate their features lying farther from the class center during optimization. Furthermore, we also add pairwise distances between class centers to encourage large margins between different classes, thus enhancing the discriminative power. Our discriminative loss is calculated as follows:

$$\mathcal{L}_D = \frac{1}{2}\sum_{i=1}^n (1-\lambda^i)\Vert \mathbf{v}^i - \mathbf{\mu}_{y^i} \Vert_2^2 + \sum_{j=1}^m \sum_{\substack{k=1 \\ k \neq j}}^m \exp \left(-\frac{\Vert\mathbf{\mu}_{j}-\mathbf{\mu}_{k}\Vert_2^2}{\sqrt{V}}\right)$$

where $y^i$ is the class index of the $i$-th sample while $\mathbf{\mu}_{j}$, $\mathbf{\mu}_{k}$, and $\mathbf{\mu}_{y^i} \in \mathbb{R}^V$ are the center vectors of the $j$-th, $k$-th, and $y^i$-th classes, respectively. Intuitively, the first term of $\mathcal{L}_D$ encourages the feature vectors of one class to be close to their corresponding center, while the second term improves the inter-class discrimination by pushing the cluster centers far away from each other. Finally, the total loss for training is computed as:

$$\mathcal{L} = \mathcal{L}_{cls} + \gamma\mathcal{L}_D$$

where $\gamma$ is the balancing coefficient between the two losses.
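The discriminative loss can be evaluated directly from its definition. The sketch below uses 2D features and two classes purely for illustration; the function name, the data layout, and all values are assumptions, and the class centers are given rather than learned.

```python
import math


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def discriminative_loss(features, labels, lambdas, centers):
    """L_D: uncertainty-weighted pull to class centers plus center repulsion."""
    v_dim = len(centers[0])   # feature dimension V
    m = len(centers)          # number of classes
    intra = 0.5 * sum(
        (1.0 - lam) * sq_dist(v, centers[y])
        for v, y, lam in zip(features, labels, lambdas)
    )
    inter = sum(
        math.exp(-sq_dist(centers[j], centers[k]) / math.sqrt(v_dim))
        for j in range(m)
        for k in range(m)
        if k != j
    )
    return intra + inter


centers = [[1.0, 0.0], [0.0, 1.0]]
labels = [0, 1]

# Features sitting exactly on their centers: only the repulsion term remains.
loss_on_centers = discriminative_loss([[1.0, 0.0], [0.0, 1.0]], labels, [0.0, 0.0], centers)

# An outlying first sample, treated as certain vs. highly uncertain.
outlier = [[3.0, 0.0], [0.0, 1.0]]
loss_certain = discriminative_loss(outlier, labels, [0.0, 0.0], centers)
loss_uncertain = discriminative_loss(outlier, labels, [0.9, 0.0], centers)
```

The same outlying feature costs far less when its label is marked as uncertain, which is the tolerance effect described above.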

References

[1] Ali Mollahosseini, Behzad Hasani, and Mohammad H. Mahoor. AffectNet: A database for facial expression, valence, and arousal computing in the wild. IEEE Transactions on Affective Computing, 2019.

[2] Shan Li, Weihong Deng, and JunPing Du. Reliable crowdsourcing and deep locality-preserving learning for expression recognition in the wild. In CVPR, 2017.

[3] B. Gao, C. Xing, C. Xie, J. Wu, and X. Geng. Deep label distribution learning with label ambiguity. IEEE Transactions on Image Processing, 2017.

Uncertainty-aware Label Distribution Learning for Facial Expression Recognition (Part 2)

In the previous post, we introduced our proposed method for facial expression recognition. In this post, we examine the effectiveness and efficiency of the proposed method.

Experimental Results

Noisy and Inconsistent Labels

Table 1. Test performance with synthetic noise.

We conduct experiments to study the robustness of our LDLVA on mislabelled data by adding synthetic noise to the AffectNet, RAF-DB, and SFEW datasets. Specifically, we randomly flip the manual labels to one of the other categories. We report the mean accuracy and standard error in Table 1. The results clearly show that our method consistently outperforms other approaches in all cases. We also observe that the improvements become more apparent as the noise ratio increases; for example, the accuracy improvement on RAF-DB is 4.7% with 10% noise and 6.93% with 30% noise. The consistent results under various settings demonstrate the ability of our method to deal effectively with noisy annotations, which is crucial for robustness against label ambiguity.

Table 2. Test performance with inconsistent labels between cross-datasets.

Since the annotations for large-scale FER data are commonly obtained via crowd-sourcing, label inconsistency can arise, especially between different datasets. To examine the effectiveness of our proposed method in dealing with this problem, we also perform experiments with the cross-dataset protocol. Table 2 shows that our method achieves the best performance on all three datasets as well as the highest average accuracy, surpassing the current state-of-the-art methods. This confirms the advantages of our method over previous works and demonstrates its ability to generalize to data with label inconsistency, which is essential for real-world FER applications.

Comparison with the State of the Art

Table 3. Comparison with recent methods on the original datasets.

We further compare our method with several state-of-the-art methods on the original AffectNet, RAF-DB, and SFEW datasets to evaluate its robustness to the uncertainty and ambiguity that unavoidably exist in real-world FER datasets. The results are presented in Table 3. By leveraging label distribution learning on the valence-arousal space, our model outperforms other methods and achieves state-of-the-art performance on AffectNet, RAF-DB, and SFEW. Although these datasets are considered to be "clean", the results suggest that they do suffer from uncertainty and ambiguity.

Qualitative Analysis

Real-world Ambiguity: To understand more about real-world ambiguous expressions, we conducted a user study in which we asked participants to choose the most clearly expressed emotion on random test images. We compare our model's predictions with the survey results in Figure 3. We can see that these images are ambiguous as they express a combination of different emotions, hence the participants do not fully agree and have different opinions about the most prominent emotion on the faces. It is further shown that our model can give consistent results and agree with the perception of humans to some degree.

Figure 3. Comparison of the results from our user study and our model.

Uncertainty Factor: Figure 4 shows the estimated uncertainty factors of some training images and their original labels. The uncertainty values decrease from top to bottom. Highly uncertain labels can be caused by low-quality inputs (as shown in Angry and Surprise columns) or ambiguous facial expressions. In contrast, when the emotions can be easily recognized as those in the last row, the uncertainty factors are assigned low values. This characteristic can guide the model to decide whether to put more weight on the provided label or the neighborhood information. Therefore, the model can be more robust against uncertainty and ambiguity.

Figure 4. Visualization of uncertainty values of some examples from RAF-DB dataset.

Conclusion

We have introduced a new label distribution learning method for facial expression recognition that leverages structure information in the valence-arousal space to recover the intensities distributed over emotion categories. The constructed label distribution provides rich information about the emotions and can thus effectively describe the degree of ambiguity of the facial image. Extensive experiments on popular datasets demonstrate the effectiveness of our method over previous approaches under inconsistency and uncertainty conditions in facial expression recognition.

References

[1] Ali Mollahosseini, Behzad Hasani, and Mohammad H. Mahoor. AffectNet: A database for facial expression, valence, and arousal computing in the wild. IEEE Transactions on Affective Computing, 2019.

[2] Shan Li, Weihong Deng, and JunPing Du. Reliable crowdsourcing and deep locality-preserving learning for expression recognition in the wild. In CVPR, 2017.

[3] B. Gao, C. Xing, C. Xie, J. Wu, and X. Geng. Deep label distribution learning with label ambiguity. IEEE Transactions on Image Processing, 2017.

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 3)

In this part, we show the effectiveness and the ablation studies of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge.

Dataset

As mentioned in [1], we train the method on two types of scans: Liver CT scans and Brain MRI scans.

For Liver CT scans, we use 5 datasets:

  1. LiTS contains 131 liver segmentation scans.
  2. MSD has 70 liver tumor CT scans, 443 hepatic vessels scans, and 420 pancreatic tumor scans.
  3. BFH is a smaller dataset with 92 scans.
  4. SLIVER is a challenging dataset with 20 liver segmentation scans annotated by 3 expert doctors.
  5. LSPIG (Liver Segmentation of Pigs) contains 17 pairs of CT scans from pigs, provided by the First Affiliated Hospital of Harbin Medical University.

For Brain MRI scans, we use 4 datasets:

  1. ADNI contains 66 scans.
  2. ABIDE contains 1287 scans.
  3. ADHD contains 949 scans.
  4. LPBA has 40 scans, each featuring a segmentation ground truth of 56 anatomical structures.

Baselines

We compare the LDR ALDK method with the following recent deformable registration methods:

  • ANTs SyN and Elastix B-spline are methods that find an optimal transformation by iteratively updating the parameters of the defined alignment.
  • VoxelMorph predicts a dense deformation in an unsupervised manner by using deconvolutional layers.
  • VTN is an end-to-end learning framework that uses convolutional neural networks to register 3D medical images, especially large displaced ones.
  • RCN is a recent recursive deep architecture that utilizes learnable cascade and performs progressive deformation for each warped image.

Results

Table 1 summarizes the overall performance, testing speed, and number of parameters of the method compared with recent state-of-the-art methods in the deformable registration task. The results clearly show that the Light-weight Deformable Registration network (LDR) accompanied by the Adversarial Learning with Distilling Knowledge (ALDK) algorithm significantly reduces the inference time and the number of parameters during the inference phase. Moreover, the method achieves competitive accuracy with the most recent high-performing but expensive networks, such as VTN or VoxelMorph. We notice that this improvement is consistent across all experiments on the different datasets: SLIVER, LiTS, LSPIG, and LPBA.

In particular, we observe that on the SLIVER dataset the Dice score of the best model with 3 cascades (3-cas LDR + ALDK) is 0.3% lower than the best result of 3-cas VTN + Affine, while inference is ~21 times faster on a CPU and the number of parameters used during inference is ~8 times smaller. Taking into account the benchmarking results on the three other datasets, i.e., LiTS, LSPIG, and LPBA, the light-weight model trades off only an average of 0.5% in Dice score and 1.25% in Jacc score for a significant gain in speed and a massive reduction in the number of parameters. We also notice that the method is the only one that achieves an inference time of approximately 1s on a CPU. This makes the method well suited for deployment, as it does not require expensive GPU hardware for inference.

Fig-1

Table 1: COMPARISON AMONG LDR ALDK MODEL WITH RECENT APPROACHES.

Ablation Study

Effectiveness of ALDK. Table 2 summarizes the effectiveness of Adversarial Learning with Distilling Knowledge (ALDK) when integrated into the light-weight student network. Note that LDR without ALDK is trained using only the reconstruction loss in an unsupervised learning setup. From this table, we clearly see that the ALDK algorithm improves the Dice score of LDR on the SLIVER dataset by 3.4%, 4.0%, and 3.1% for the 1-cas, 2-cas, and 3-cas setups, respectively. Additionally, using ALDK increases the Jacc score by 5.2%, 4.9%, and 3.9% for 1-cas LDR, 2-cas LDR, and 3-cas LDR. These results verify the stability of the adversarial learning algorithm in the inference phase across different evaluation metrics and numbers of cascades. Furthermore, Table 2 also clearly shows the effectiveness and generalization of ALDK when applied to the student network. Since the deformations extracted from the teacher are used only during training, the adversarial learning algorithm fully maintains the speed and the number of parameters of the light-weight student network during inference. All results indicate that the student network combined with the adversarial learning algorithm successfully achieves the performance goal while maintaining the efficient computational cost of the light-weight setup.

Fig-2

Table 2: COMPARISON AMONG LDR ALDK MODEL WITH RECENT APPROACHES.

Accuracy vs. Complexity. Figure 1 shows the experimental results on the SLIVER dataset for LDR + ALDK and the baseline VTN under multiple recursive cascade setups on both CPU and GPU. On the CPU (Figure 1-a), in the 1-cascade setup, the Dice score of the method is 0.2% lower than that of VTN while the speed is ~15 times faster. The more cascades are used, the larger the speed gap between LDR + ALDK and the baseline VTN; e.g., the CPU speed gap increases to ~21 times in the 3-cascade setup. We also observe the same effect on GPU (Figure 1-b), where the method achieves slightly lower accuracy than VTN while clearly reducing the inference time. These results indicate that LDR + ALDK can work well with the teacher network to improve the accuracy while significantly reducing the inference time on both CPU and GPU in comparison with the baseline VTN network.

Fig-3

Figure 1: Plots of Dice score and inference speed with respect to the number of cascades of the baseline Affine + VTN and LDR + ALDK. (a) for CPU speed and (b) for GPU speed. Note that results are reported for the SLIVER dataset; bars represent the CPU speed; lines represent the Dice score. All methods use an Intel Xeon E5-2690 v4 CPU and Nvidia GeForce GTX 1080 Ti GPU for inference.

Visualization

Figure 2 illustrates the visual comparison among 1-cas LDR, 1-cas LDR + ALDK, and the baseline 1-cas RCN. Five different moving images in a volume are selected and registered to a chosen fixed image. It is important to note that although the sections of the warped segmentations may overlap less with those of the fixed one, the segmentation intersection over union is computed for the volume and not for the sections. In the segmented images in Figure 2, besides the matched areas colored in white, we also mark the mismatched areas in red for readability.

From Figure 2, we can see that the segmentation results of the 1-cas LDR network without ALDK (Figure 2-a) contain many mismatched areas (shown in red). However, when we apply ALDK to the student network, the registration results are clearly improved (Figure 2-b). Overall, the LDR + ALDK visualization results in Figure 2-b are competitive with the baseline RCN network (Figure 2-c). This visualization confirms that the framework for deformable registration can achieve results comparable with the recent RCN network.

Fig-3

Figure 2: The visual comparison between LDR (a), LDR + ALDK (b), and the baseline RCN (c). The left images are sections of the warped images; the right images are sections of the warped segmentation (white represents the matched areas between the warped image and the fixed image, red denotes the mismatched areas). The segmentation visualization indicates that the LDR + ALDK method (b) significantly reduces the mismatched areas of the student network LDR (a). Best viewed in color.

Reference

[1] Tran, Minh Q., et al. "Light-weight deformable registration using adversarial learning with distilling knowledge." IEEE Transactions on Medical Imaging, 2022.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 2)

In this part, we introduce the architecture of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge.

The Architecture of Light-weight Deformable Registration Network

In practice, recent deformation networks follow an encoder-decoder architecture and use 3D convolution to progressively down-sample the image, and deconvolution (transposed convolution) to recover spatial resolution [1, 3]. However, this setup consumes a large number of parameters. Therefore, the built models are computationally expensive and time-consuming. To overcome this problem we design a new light-weight student network as illustrated in Figure 1.

In particular, the proposed light-weight network has four convolution layers and three deconvolution layers. Each convolutional layer has a bank of $4 \times 4 \times 4$ filters with strides of $2 \times 2 \times 2$, followed by a ReLU activation function. The number of output channels of the convolutional layers starts with $16$ at the first layer, doubling at each subsequent layer, and ends up with $128$. Skip connections between the convolutional layers and the deconvolutional layers are added to help refine the dense prediction. The subnetwork outputs a dense flow prediction field, i.e., a $3$-channel volume feature map with the same size as the input.
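To get a feel for the size of the encoder described above, the following back-of-the-envelope sketch counts weights and tracks the spatial resolution through the four convolution layers. Only the kernel size, stride, and channel progression come from the text; the 2 input channels (fixed and moving image concatenated), the padding of 1, the bias terms, and the 128-voxel input side are assumptions made for this example.

```python
def conv3d_params(in_ch, out_ch, kernel=4, bias=True):
    """Number of weights in a 3D convolution with a cubic kernel."""
    return in_ch * out_ch * kernel ** 3 + (out_ch if bias else 0)


def conv_out_size(size, kernel=4, stride=2, padding=1):
    """Output side length of a strided convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


in_channels = 2     # assumption: fixed + moving volumes stacked as channels
input_size = 128    # assumption: cubic input volume of side 128

# 16 output channels at the first layer, doubling up to 128.
encoder_channels = [16 * 2 ** i for i in range(4)]

sizes, params = [], []
prev_ch, size = in_channels, input_size
for out_ch in encoder_channels:
    size = conv_out_size(size)
    sizes.append(size)
    params.append(conv3d_params(prev_ch, out_ch))
    prev_ch = out_ch

total_encoder_params = sum(params)
```

Under these assumptions each layer halves the resolution, and most of the encoder weights sit in the last, widest layer. The decoder is left out because the text does not give its channel widths.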

In comparison with the current state-of-the-art dense deformable registration network [3], the number of parameters of our proposed light-weight student network is reduced approximately $10$ times. In practice, this significant reduction may lead to an accuracy drop. Therefore, we propose a new Adversarial Learning with Distilling Knowledge algorithm to effectively transfer the teacher deformations $\phi_t$ to our introduced student network, keeping it light-weight while achieving competitive performance.

Fig-1

Figure 1: The structure of Light-weight Deformable Registration student network. The number of channels is annotated above the layer. Curved arrows represent skip paths (layers connected by an arrow are concatenated before transposed convolution). Smaller canvas means lower spatial resolution (Source).

Adversarial Learning Algorithm with Distilling Knowledge

Our adversarial learning algorithm aims to improve the student network accuracy through the distilled teacher deformations extracted from the teacher network. The learning method comprises a deformation-based adversarial loss $\mathcal{L}_{adv}$ and its accompanying learning strategy (Algorithm 1).

Fig-2

Figure 2: Adversarial Learning Strategy (Source).

Adversarial Loss. The loss function for the light-weight student network is a combination of the discrimination loss $l_{dis}$ and the reconstruction loss $l_{rec}$. However, the forward and backward passes through the loss function are controlled by Algorithm 1. In particular, the last deformation loss $\mathcal{L}_{adv}$, computed on the deformation that outputs the final warped image, can be written as:

$$\mathcal{L}_{adv} = \gamma l_{rec} + (1 - \gamma) l_{dis}$$

where $\gamma$ controls the relative contribution of $l_{rec}$ and $l_{dis}$. Note that $\mathcal{L}_{adv}$ is only applied to the final warped image.

Discrimination Loss. In the student network, the discrimination loss is computed as follows:

$$l_{dis} = \left\lVert D_\mathbf{\theta}(\phi_{s}) - D_\mathbf{\theta}(\phi_{t}) \right\rVert_2^{2} + \lambda\bigg(\left\lVert \nabla_{\hat\phi_{s}}D_\mathbf{\theta}(\hat\phi_{s}) \right\rVert_2 - 1\bigg)^{2}$$

where $\lambda$ controls the gradient penalty regularization. The joint deformation $\hat\phi_{s}$ is computed from the teacher deformation $\phi_{t}$ and the predicted student deformation $\phi_{s}$ as follows:

$$\hat\phi_{s} = \beta \phi_{t} + (1 - \beta) \phi_{s}$$

where $\beta$ controls the effect of the teacher deformation.

In the discrimination loss, $D_\mathbf{\theta}$ is the discriminator, formed by a neural network with learnable parameters ${\theta}$. The details of $D_\mathbf{\theta}$ are shown in Figure 3. In particular, $D_\mathbf{\theta}$ consists of six 3D convolutional layers. The first layer is $128 \times 128 \times 128 \times 3$ and takes the $c \times c \times c \times 1$ deformation as input, where $c$ is equal to the scaled size of the input image. The second layer is $64 \times 64 \times 64 \times 16$. From the second layer to the last convolutional layer, each convolutional layer has a bank of $4 \times 4 \times 4$ filters with strides of $2 \times 2 \times 2$, followed by a ReLU activation function, except for the last layer, which is followed by a sigmoid activation function. The number of output channels of the convolutional layers starts with $16$ at the second layer, doubling at each subsequent layer, and ends up with $256$.

Basically, this is to inject the condition information with a matched tensor dimension and then leave the network to learn useful features from the condition input. The output of the last neural layer is the mean feature of the discriminator, denoted as $M$. Note that in the discrimination loss, a gradient penalty regularization is applied to deal with critic weight clipping, which may lead to undesired behavior in training adversarial networks.
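The losses above can be sketched numerically as follows. This is only an illustration under stated assumptions: deformations are flattened into short lists, and `toy_discriminator` is a hypothetical scalar stand-in (a sigmoid over a linear score with weights `w`) for the six-layer 3D discriminator $D_\mathbf{\theta}$, chosen so that its input gradient has a closed form. All function names are illustrative.

```python
import math


def joint_deformation(phi_t, phi_s, beta):
    """phi_hat = beta * phi_t + (1 - beta) * phi_s, element-wise."""
    return [beta * t + (1.0 - beta) * s for t, s in zip(phi_t, phi_s)]


def toy_discriminator(phi, w):
    """Hypothetical stand-in for D_theta: sigmoid of a linear score."""
    score = sum(wi * pi for wi, pi in zip(w, phi))
    return 1.0 / (1.0 + math.exp(-score))


def toy_gradient_norm(phi, w):
    """L2 norm of dD/dphi for the toy discriminator: D * (1 - D) * ||w||."""
    d = toy_discriminator(phi, w)
    return d * (1.0 - d) * math.sqrt(sum(wi * wi for wi in w))


def discrimination_loss(phi_s, phi_t, w, beta, lam):
    """Squared score gap plus the gradient penalty evaluated at phi_hat."""
    gap = toy_discriminator(phi_s, w) - toy_discriminator(phi_t, w)
    phi_hat = joint_deformation(phi_t, phi_s, beta)
    penalty = (toy_gradient_norm(phi_hat, w) - 1.0) ** 2
    return gap ** 2 + lam * penalty


def adversarial_loss(l_rec, l_dis, gamma):
    """L_adv = gamma * l_rec + (1 - gamma) * l_dis."""
    return gamma * l_rec + (1.0 - gamma) * l_dis


# Illustrative values only; none of these numbers come from the post.
phi_t = [0.5, -0.2, 0.1]
phi_s = [0.3, 0.0, 0.4]
l_dis = discrimination_loss(phi_s, phi_t, w=[1.0, 2.0, -1.0], beta=0.5, lam=10.0)
print(adversarial_loss(l_rec=0.2, l_dis=l_dis, gamma=0.5))
```

In an actual training setup the gradient with respect to $\hat\phi_{s}$ would come from automatic differentiation rather than a closed form.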

Fig-3

Figure 3: The structure of the discriminator $D_\mathbf{\theta}$ used in the Discrimination Loss ($l_{dis}$) of our Adversarial Learning with Distilling Knowledge algorithm (Source).

Reconstruction Loss. The reconstruction loss $l_{rec}$ is an important part of a deformation estimator. Following the VTN [3] baseline, the reconstruction loss is written as:

$$l_{rec}(\textbf{\textit{I}}_m^h, \textbf{\textit{I}}_f) = 1 - CorrCoef[\textbf{\textit{I}}_m^h, \textbf{\textit{I}}_f]$$

where

$$CorrCoef[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2] = \frac{Cov[\textbf{\textit{I}}_1,\textbf{\textit{I}}_2]}{\sqrt{Cov[\textbf{\textit{I}}_1,\textbf{\textit{I}}_1]Cov[\textbf{\textit{I}}_2,\textbf{\textit{I}}_2]}}$$

$$Cov[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2] = \frac{1}{|\omega|}\sum_{x \in \omega} \textbf{\textit{I}}_1(x)\textbf{\textit{I}}_2(x) - \frac{1}{|\omega|^{2}}\sum_{x \in \omega} \textbf{\textit{I}}_1(x)\sum_{y \in \omega}\textbf{\textit{I}}_2(y)$$

where $CorrCoef[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2]$ is the correlation between two images $\textbf{\textit{I}}_1$ and $\textbf{\textit{I}}_2$, and $Cov[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2]$ is the covariance between them. $\omega$ denotes the cuboid (or grid) on which the input images are defined.
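As a quick check of the formulas above, here is a minimal sketch of the correlation-based reconstruction loss. It assumes each image has been flattened into a list of voxel intensities over the same grid $\omega$; the function names are illustrative, and a constant image (zero variance) is not handled.

```python
import math


def cov(a, b):
    """Covariance as defined above: mean(a * b) - mean(a) * mean(b)."""
    n = len(a)
    return sum(x * y for x, y in zip(a, b)) / n - (sum(a) / n) * (sum(b) / n)


def corr_coef(a, b):
    """Correlation coefficient between two flattened images."""
    return cov(a, b) / math.sqrt(cov(a, a) * cov(b, b))


def correlation_reconstruction_loss(warped, fixed):
    """l_rec = 1 - CorrCoef[warped, fixed]; 0 means perfectly correlated."""
    return 1.0 - corr_coef(warped, fixed)


fixed = [0.0, 1.0, 2.0, 3.0]
brighter = [2.0 * v + 1.0 for v in fixed]  # same structure, different contrast
inverted = [-v for v in fixed]             # perfectly anti-correlated
print(correlation_reconstruction_loss(brighter, fixed))
print(correlation_reconstruction_loss(inverted, fixed))
```

Because the loss depends only on correlation, a warped image that differs from the fixed image by a positive linear intensity change still scores as a perfect match.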

Learning Strategy. The forward and backward passes of the aforementioned $\mathcal{L}_{adv}$ are controlled by the adversarial learning strategy described in Algorithm 1.

In our deformable registration setup, the roles of real data and attacking data are reversed when compared with the traditional adversarial learning strategy. In adversarial learning, the model uses unreal (generated) images as attacking data, while image labels are ground truths. However, in our deformable registration task, the model leverages the unreal (generated) deformations from the teacher as attacking data, while the image is the ground truth for the model to reconstruct the input information. As a consequence, the roles of the images and the labels are reversed in our setup. Since we want more of the information to be learned from real data, the generator needs to be updated more frequently. Although the knowledge in the discriminator is used as attacking data, the information it provides is meaningful because the distilled information is inherited from the high-performing teacher model. With these characteristics of both the generator and the discriminator, the light-weight student network is expected to learn more effectively and efficiently.

Reference

[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., Recursive cascaded networks for unsupervised medical image registration, in ICCV, 2019.

[2] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, ArXiv, 2015.

[3] S. Zhao, T. Lau, J. Luo, I. Eric, C. Chang, and Y. Xu, Unsupervised 3d end-to-end medical image registration with volume tweening network, IEEE J-BHI, 2019.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge

Introduction: Medical image registration

Medical image registration is the process of systematically placing separate medical images in a common frame of reference so that the information they contain can be effectively integrated or compared. Applications of image registration include combining images of the same subject from different modalities, aligning temporal sequences of images to compensate for the motion of the subject between scans, aligning images from multiple subjects in cohort studies, or navigating with image guidance during interventions. Since many organs do deform substantially while being scanned, the rigid assumption can be violated as a result of scanner-induced geometrical distortions that differ between images. Therefore, performing deformable registration is an essential step in many medical procedures.

Previous Studies, Remaining Challenges, and Motivation

Recently, learning-based methods have become popular to tackle the problem of deformable registration. These methods can be split into two groups: (i) supervised methods that rely on the dense ground-truth flows obtained by either traditional algorithms or simulating intra-subject deformations. Although these works achieve state-of-the-art performance, they require a large amount of manually labeled training data, which is expensive to obtain; and (ii) unsupervised learning methods that use a similarity measurement between the moving and the fixed image to utilize a large amount of unlabelled data. These unsupervised methods achieve competitive results in comparison with supervised methods. However, their deformations are reconstructed without direct ground-truth guidance, which limits the amount of learnable information they can leverage. Furthermore, recent unsupervised methods all share an issue of great complexity, as the number of network parameters increases significantly when multiple progressive cascades are taken into account. As a result, these works cannot achieve real-time performance during inference and require intensive computational resources for deployment.

In practice, there are many scenarios in which medical image registration needs to be fast: consider matching preoperative and intra-operative images during surgery, interactive change detection of CT or MRI data for a radiologist, deformation compensation or 3D alignment of large histological slices for a pathologist, or processing large amounts of images from high-throughput imaging methods. Besides, in many image-guided robotic interventions, performing real-time deformable registration is an essential step to register the images and deal with organs that deform substantially. Economically, the development of a CPU-friendly solution for deformable registration will significantly reduce the cost of the instruments equipped in the operating theatre, as it does not require GPUs or cloud-based computing servers, which are costly and consume much more power than CPUs. This will benefit patients in low- and middle-income countries, which face limitations in local equipment, personnel expertise, and budget-constrained infrastructure. Therefore, designing an efficient model that is both fast and accurate for deformable registration is a crucial task and worth studying in order to improve a variety of surgical interventions.

Contribution

Deformable registration is a crucial step in many medical procedures such as image-guided surgery and radiation therapy. Most recent learning-based methods focus on improving the accuracy by optimizing the non-linear spatial correspondence between the input images. Therefore, these methods are computationally expensive and require modern graphics cards for real-time deployment. Thus, we introduce a new Light-weight Deformable Registration network that significantly reduces the computational cost while achieving competitive accuracy (Fig.1). In particular, we propose a new adversarial learning with distilling knowledge algorithm that successfully transfers meaningful information from the effective but expensive teacher network to the student network. We design the student network so that it is light-weight and well suited for deployment on a typical CPU. Extensive experimental results on different public datasets show that our proposed method achieves state-of-the-art accuracy while being significantly faster than recent methods. We further show that the use of our adversarial learning algorithm is essential for a time-efficient deformable registration method.

Fig-1

(a)
(b)
Figure 1: Comparison between typical deep learning-based methods for deformable registration (a) and our approach using adversarial learning with distilling knowledge for deformable registration (b). In our work, the expensive Teacher Network is used only in training; the Student Network is light-weight and inherits helpful knowledge from the Teacher Network via our Adversarial Learning algorithm. Therefore, the Student Network has high inference speed, while achieving competitive accuracy (Source).

Methodology

Method overview

We describe our method for Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge. Our method is composed of three main components: (i) a Knowledge Distillation module which extracts meaningful deformations $\bm{\phi}_t$ from the Teacher Network; (ii) a Light-weight Deformable Registration (LDR) module which outputs a high-speed Student Network; and (iii) an Adversarial Learning with Distilling Knowledge (ALDK) algorithm which effectively transfers the teacher deformations $\bm{\phi}_t$ to the student deformations. An overview of our proposed deformable registration method can be found in Fig.2.

Fig-2

Figure 2: An overview of our proposed Light-weight Deformable Registration (LDR) method using Adversarial Learning with Distilling Knowledge (ALDK). Firstly, by using knowledge distillation, we extract the deformations from the Teacher Network as meaningful ground-truths. Secondly, we design a light-weight student network, which has competitive speed. Finally, we employ the Adversarial Learning with Distilling Knowledge algorithm to effectively transfer the meaningful knowledge of distilled deformations from the Teacher Network to the Student Network (Source).

Since the content is lengthy, in this part we introduce the background theory for Deformable Registration and Knowledge Distillation for Deformation. In the next part, we will introduce the Architecture of Light-weight Deformable Registration Network and Adversarial Learning Algorithm with Distilling Knowledge. In the final part, we will present the effectiveness of the method in comparison with recent state-of-the-art methods, along with a detailed analysis.

Background: Deformable Registration

We follow RCN [1] to define the deformable registration task recursively using multiple cascades. Let $\textbf{\textit{I}}_m, \textbf{\textit{I}}_f$ denote the moving image and the fixed image respectively, both defined over $d$-dimensional space $\bm{\Omega}$. A deformation is a mapping $\bm{\phi} : \bm{\Omega} \rightarrow \bm{\Omega}$. A reasonable deformation should be continuously varying and prevented from folding. The deformable registration task is to construct a flow prediction function $\textbf{F}$ which takes $\textbf{\textit{I}}_m, \textbf{\textit{I}}_f$ as inputs and predicts a dense deformation $\bm{\phi}$ that aligns $\textbf{\textit{I}}_m$ to $\textbf{\textit{I}}_f$ using a warp operator $\circ$ as follows:

$$\textbf{F}^{(n)}(\textbf{\textit{I}}^{(n-1)}_m,\textbf{\textit{I}}_f)=\phi^{(n)} \circ \textbf{F}^{(n-1)}(\phi^{(n-1)} \circ \textbf{\textit{I}}^{(n-2)}_m,\textbf{\textit{I}}_f)$$

where $\textbf{F}^{(n-1)}$ has the same form as $\textbf{F}^{(n)}$, but is a different flow prediction function. Assuming $n$ cascades in total, the final output is a composition of all predicted deformations, i.e.,

$$\textbf{F}(\textbf{\textit{I}}_m, \textbf{\textit{I}}_f)=\phi^{(n)} \circ \dots \circ \phi^{(1)},$$

and the final warped image is constructed by

$$\textbf{\textit{I}}_{m}^{(n)}=\textbf{F}(\textbf{\textit{I}}_m,\textbf{\textit{I}}_f) \circ \textbf{\textit{I}}_m$$

In general, the previous equations form the hypothesis function $\mathcal{F}$ under the learnable parameter $\mathbf{W}$,

$$\mathcal{F}(\textbf{\textit{I}}_{m}, \textbf{\textit{I}}_f, \mathbf{W}) = (\mathbf{v}_{\phi}, \textbf{\textit{I}}_m^{(n)})$$

where $\mathbf{v}_{\phi} = [\bm{\phi}^{(1)}, \bm{\phi}^{(2)}, \dots, \bm{\phi}^{(k)}, \dots, \bm{\phi}^{(n)}]$ is a vector containing the predicted deformations of all cascades. Each deformation $\bm{\phi}^{(k)}$ can be computed as

$$\bm{\phi}^{(k)} = {\mathcal{F}}^{(k)}\left(\textbf{\textit{I}}_{m}^{(k-1)}, \textbf{\textit{I}}_f, \mathbf{W}_{\phi^{(k)}}\right)$$

To estimate and achieve a good deformation, different networks are introduced to define and optimize the learnable parameter $\mathbf{W}$.
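The recursive definition can be illustrated with a deliberately small toy, not the RCN implementation. The assumptions here are substantial: images are 1D lists, a deformation is a list of source indices, and `predict_flow` is a hypothetical stand-in for a learned cascade that simply picks the best circular shift of at most one position. The point is only to show how each cascade warps the output of the previous one and how the deformations compose.

```python
def warp(phi, image):
    """Apply a deformation given as source indices: out[i] = image[phi[i]]."""
    return [image[p] for p in phi]


def predict_flow(moving, fixed):
    """Hypothetical cascade: best circular shift in {-1, 0, 1} by squared error."""
    n = len(moving)
    candidates = [[(i + s) % n for i in range(n)] for s in (-1, 0, 1)]

    def ssd(phi):
        return sum((m - f) ** 2 for m, f in zip(warp(phi, moving), fixed))

    return min(candidates, key=ssd)


def register(moving, fixed, num_cascades):
    """Run the cascades; return all deformations, their composition, final image."""
    flows = []
    composed = list(range(len(moving)))  # identity deformation
    current = moving
    for _ in range(num_cascades):
        phi = predict_flow(current, fixed)  # phi^(k) from I_m^(k-1) and I_f
        flows.append(phi)
        current = warp(phi, current)        # I_m^(k) = phi^(k) o I_m^(k-1)
        composed = warp(phi, composed)      # accumulate the composed deformation
    return flows, composed, current


moving = [0, 0, 1, 2, 3, 0, 0, 0]
fixed = [1, 2, 3, 0, 0, 0, 0, 0]
flows, composed, warped = register(moving, fixed, num_cascades=3)
print(warped == fixed, warp(composed, moving) == warped)
```

No single cascade can recover the two-position shift on its own, but the cascades together do, and warping the original moving image once with the composed deformation gives the same result as warping it step by step.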

Knowledge Distillation for Deformation

Knowledge distillation is the process of transferring knowledge from a cumbersome model (teacher model) to a distilled model (student model). The popular way to achieve this goal is to train the student model on a transfer set using a soft target distribution produced by the teacher model.

Different from the typical knowledge distillation methods that target the output softmax of neural networks as the knowledge, in the deformable registration task, we leverage the teacher deformation $\bm{\phi}_t$ as the transferred knowledge. As discussed in [2], teacher networks are usually high-performing networks with good accuracy. Therefore, our goal is to leverage the current state-of-the-art Recursive Cascaded Networks (RCN) [1] as the teacher network for extracting meaningful deformations to the student network. The RCN network contains an affine transformation and a large number of dense deformable registration sub-networks designed by VTN [3]. Although the teacher network has expensive computational costs, it is only applied during training and will not be used during inference.

Reference

[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., Recursive cascaded networks for unsupervised medical image registration, in ICCV, 2019.

[2] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, ArXiv, 2015.

[3] S. Zhao, T. Lau, J. Luo, I. Eric, C. Chang, and Y. Xu, Unsupervised 3d end-to-end medical image registration with volume tweening network, IEEE J-BHI, 2019.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Medical Visual Question Answering Challenges

Visual Question Answering in Medical Domain

Visual Question Answering (VQA) aims to provide a correct answer for a given question consistent with the visual content of a given image. The overarching goal of this issue is to create systems that can comprehend the contents of an image in the same way that humans do and communicate effectively about that image in natural language. It is indeed a challenging task as it necessitates the interaction and complementation of both image feature extractor and natural language processor.

In the medical domain, VQA could benefit both doctors and patients. VQA systems capable of understanding medical images and answering questions related to their content could support clinical education, clinical decision making, and patient education. For example, doctors could use answers provided by a VQA system as support materials in decision making. It can also help doctors get a second opinion in diagnosis and reduce the high cost of training medical professionals. Patients, in turn, could ask a VQA system questions related to their medical images to better understand their health.

Compared with VQA in the general domain, Med-VQA is a much more challenging problem. On one hand, clinical questions are more difficult but need to be answered with higher accuracy since they relate to health and safety.

Fig-1

Figure 1: An example of Medical VQA (Source).

There are challenges that need to be addressed when using VQA in the medical domain:

  • Lack of large scale labeled datasets.
  • Using Transfer learning in medical domain.

In this blog, we will analyze these challenges.

Challenges in Medical VQA

1. Lack of large scale labeled datasets.

One major problem with medical VQA is the lack of large scale labeled training data, which usually requires huge efforts to build. Difficulties in building a medical VQA dataset include:

  • (i) designing goal-oriented VQA systems and datasets
  • (ii) categorizing the clinical questions
  • (iii) selecting (clinically) relevant images
  • (iv) capturing the context and the medical knowledge.

The first attempt at building a dataset for medical VQA was made by ImageCLEF-Med. In this dataset, images were automatically captured from PubMed Central articles, and the questions and answers were automatically generated from the corresponding image captions. Because of that construction, the data has a high noise level, i.e., the dataset includes many images that are not useful for direct patient care and it also contains questions that do not make any sense.

Well-annotated datasets for training Medical VQA systems are extremely scarce, since it is very laborious and expensive to obtain high-quality annotations from medical experts. VQA-RAD is the first manually constructed dataset released for the medical VQA task. Unfortunately, it contains only 315 images, which prevents the direct application of powerful deep learning models to the VQA problem.

2. Using Transfer learning in a specific domain.

Transfer learning is an important step to extract meaningful features and overcome the data limitation in the medical Visual Question Answering (VQA) task. Transfer learning, in which deep learning models pretrained on a large scale labeled dataset such as ImageNet are reused, is a popular way to initialize the feature extraction process. However, due to the difference in visual concepts between ImageNet images and medical images, the finetuning process is not sufficient. Recently, Model Agnostic Meta-Learning (MAML) has been introduced to overcome the aforementioned problem by learning meta-weights that quickly adapt to visual concepts. However, MAML is heavily impacted by the meta-annotation phase for all images in the medical dataset. Different from normal images, transfer learning in medical images is more challenging due to:

  • (i) noisy labels may occur when labeling images in an unsupervised manner;
  • (ii) high-level semantic labels cause uncertainty during learning;
  • (iii) difficulty in scaling up the process to all unlabeled images in medical datasets.

Multiple Meta-model Quantifying for Medical Visual Question Answering

Motivation

A medical Visual Question Answering (VQA) system can provide meaningful references for both doctors and patients during the treatment process. Extracting image features is one of the most important steps in a medical VQA framework, as it outputs essential information to predict answers. Transfer learning, in which deep learning models pretrained on a large scale labeled dataset such as ImageNet are reused, is a popular way to initialize the feature extraction process. However, due to the difference in visual concepts between ImageNet images and medical images, the finetuning process is not sufficient. Recently, Model Agnostic Meta-Learning (MAML) has been introduced to overcome the aforementioned problem by learning meta-weights that quickly adapt to visual concepts. However, MAML is heavily impacted by the meta-annotation phase for all images in the medical dataset. Different from normal images, transfer learning in medical images is more challenging due to:

  • (i) noisy labels may occur when labeling images in an unsupervised manner;
  • (ii) high-level semantic labels cause uncertainty during learning;
  • (iii) difficulty in scaling up the process to all unlabeled images in medical datasets.

Federated Learning, Challenges and Hot Trends for research

Federated learning and its remaining challenge

Federated learning is the process of training statistical models via a network of distant devices or siloed data centers, such as mobile phones or hospitals, while keeping data locally. In terms of federated learning, there are five major obstacles that have a direct impact on the paper publishing trend.

1. Expensive Communication

Due to the internet connection, huge number of users, and administrative costs, there is a bottleneck in communication between devices and server-devices.

2. Systems Heterogeneity

Because of differences in hardware (CPU, RAM), network connection (3G, 4G, 5G, wifi), and power (battery level), each device in federated networks may have different storage, computational, and communication capabilities.

3. Statistical Heterogeneity

Devices routinely create and collect data in a non-identically distributed manner across the network; for example, in the context of next word prediction, mobile phone users employ a variety of languages. Furthermore, the quantity of data points on different devices may differ greatly, and an underlying structure may exist that describes the interaction between devices and their related distributions. This data generation paradigm violates the independent and identically distributed (I.I.D.) assumptions frequently used in distributed optimization, increases the likelihood of stragglers, and may add complexity in terms of modeling, analysis, and evaluation.

4. Privacy Concerns

In federated learning applications, privacy is a crucial problem. Federated learning takes a step toward data protection by sharing model changes, such as gradient information, rather than the raw data created on each device. Nonetheless, transmitting model updates during the training process may divulge sensitive information to a third party or the central server.

5. Domain transfer

Not every task can be trained under the federated learning paradigm, due to the aforementioned four challenges.

Hot trends

Data distribution heterogeneity and label inadequacy.

  • Distributed Optimization
  • Non-IID and Model Personalization
  • Semi-Supervised Learning
  • Vertical Federated Learning
  • Decentralized FL
  • Hierarchical FL
  • Neural Architecture Search
  • Transfer Learning
  • Continual Learning
  • Domain Adaptation
  • Reinforcement Learning
  • Bayesian Learning

Security, privacy, fairness, and incentive mechanisms:

  • Adversarial-Attack-and-Defense
  • Privacy
  • Fairness
  • Interpretability
  • Incentive Mechanism

Communication and computational resource constraints, software and hardware heterogeneity, the FL system

  • Communication-Efficiency
  • Straggler Problem
  • Computation Efficiency
  • Wireless Communication and Cloud Computing
  • FL System Design

Models and Applications

  • Models
  • Natural language Processing
  • Computer Vision
  • Health Care
  • Transportation
  • Recommendation System
  • Speech
  • Finance
  • Smart City
  • Robotics
  • Networking
  • Blockchain
  • Other

Benchmark, Dataset, and Survey

  • Benchmark and Dataset
  • Survey

Deep Metric Learning Meets Deep Clustering - An Novel Unsupervised Approach for Feature Embedding

Motivation

The objective of deep distance metric learning (DML) is to train a deep learning model that maps training samples into feature embeddings that are close together for samples that belong to the same category and far apart for samples from different categories. Traditional DML approaches require supervised information, i.e., class labels, to supervise the training. Although supervised DML achieves impressive results on different tasks, it requires a large amount of annotated training samples to train the model. Unfortunately, such large datasets are not always available and they are costly to annotate for specific domains. That disadvantage also limits the transferability of supervised DML to new domains or applications which do not have labeled data. These reasons have motivated recent studies aiming at learning feature embeddings without annotated datasets, i.e., unsupervised deep distance metric learning (UDML). Our study is in that same direction, i.e., learning embeddings from unlabeled data.

There are two main challenges for UDML:

  • Firstly, how to define positive and negative samples for a given anchor data point, such that we can apply distance-based losses, e.g., pairwise loss or triplet loss, in the embedding space.
  • Secondly, how to make the training efficient, given a large number of pairs or triplets of samples, in the order of $\mathcal{O}(N^2)$ or $\mathcal{O}(N^3)$, respectively, in which $N$ is the number of training samples.
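To give a feel for the second challenge, the short sketch below counts candidate pairs and triplets for a few dataset sizes. The counting conventions are assumptions for illustration: unordered pairs, and ordered (anchor, second, third) triplets of distinct samples before any positive or negative filtering.

```python
from math import comb


def num_pairs(n):
    """Number of unordered sample pairs, which grows as O(N^2)."""
    return comb(n, 2)


def num_triplets(n):
    """Number of ordered triplets of distinct samples, which grows as O(N^3)."""
    return n * (n - 1) * (n - 2)


for n in (100, 1_000, 10_000):
    print(f"N={n:>6}: pairs={num_pairs(n):,} triplets={num_triplets(n):,}")
```

Even a modest dataset produces far more triplets than can be visited during training, which is why the sample-to-centroid comparisons introduced later are attractive.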

In this paper, we propose a new method that utilizes deep clustering for deep metric learning to address the two challenges mentioned above. In particular,

  • We propose to use a deep clustering loss to learn centroids, i.e., pseudo labels, that represent semantic classes.
  • During learning, these centroids are also used to reconstruct the input samples. It hence ensures the representativeness of centroids — each centroid represents visually similar samples. Therefore, the centroids give information about positive (visually similar) and negative (visually dissimilar) samples.
  • Based on pseudo labels, we propose a novel unsupervised metric loss which enforces the positive concentration and negative separation of samples in the embedding space.

Methodology

Fig-1

Figure 1: Illustration of the proposed framework which consists of an encoder (G), an embedding module (F), a decoder (D) and three losses, i.e., clustering loss $L_{rim}$, reconstruction loss $L_{rec}$ and metric loss $L_m$. The details are presented in the text (Source).

The proposed framework is presented in Figure 1.

  • For every original image in a batch, we make an augmented version by using a random geometric transformation.
  • The input images are fed into the backbone network which is also considered as the encoder (G) to get image representations.
  • The image representations are passed through the embedding module which consists of fully connected and L2 normalization layers (F), which results in unit norm image embeddings.
  • The clustering module takes image embeddings as inputs, performs the clustering with a clustering loss, and outputs the cluster assignments.
  • Given the cluster assignments, centroid representations are computed from image representations, which are then passed through the decoder (D) with a reconstruction loss to reconstruct images that belong to the corresponding clusters.
  • The centroid representations are also passed through the embedding module (F) to get centroid embeddings. The centroid embeddings and image embeddings are used as inputs for the metric loss.
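Two of the steps above, computing centroid representations from cluster assignments and producing unit-norm vectors, can be sketched as follows. This is an illustration under assumptions: the centroid is taken to be the mean of the representations assigned to a cluster (the text does not state the exact rule), the fully connected layer of the embedding module is omitted, and the function names are hypothetical.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit L2 norm, as the embedding module's last step does."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def centroid_representations(representations, assignments, num_clusters):
    """Mean representation per cluster (assumed rule); None for empty clusters."""
    dim = len(representations[0])
    sums = [[0.0] * dim for _ in range(num_clusters)]
    counts = [0] * num_clusters
    for rep, cluster in zip(representations, assignments):
        counts[cluster] += 1
        for d in range(dim):
            sums[cluster][d] += rep[d]
    return [
        [s / counts[c] for s in sums[c]] if counts[c] else None
        for c in range(num_clusters)
    ]


reps = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
centroids = centroid_representations(reps, assignments=[0, 0, 1], num_clusters=2)
print(centroids)
print([l2_normalize(c) for c in centroids])
```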

Discriminative Clustering

We formulate the clustering of embedding features as a classification problem. Consider a set of embedding features $X=\{x_i\}_{i=1}^m \in \mathbb{R}^{128 \times m}$ in a batch and the number of clusters $K \le m$ (i.e., the number of clusters $K$ is limited by the batch size $m$). Let $Y=\{y_i\}_{i=1}^m$ be the set of softmax outputs for $X$. The cluster assignment for $x$ is then estimated by $c^* = \arg\max_c y$, where $y$ is the softmax output for $x$.

Inspired by Regularized Information Maximization (RIM), we use the following objective function (1) for the clustering.

$$L_{rim} = \mathcal{R}(\theta) - \lambda \left[ H(Y) - H(Y|X) \right] \tag{1}$$

where $H(.)$ and $H(.|.)$ are entropy and conditional entropy, respectively; $\mathcal{R}(\theta)$ regularizes the classifier parameters (in this work we use $l_2$ regularization); $\lambda$ is a weighting factor to control the importance of the two terms.

Minimizing (1) is equivalent to maximizing $H(Y)$ and minimizing $H(Y|X)$. Increasing the marginal entropy $H(Y)$ encourages cluster balancing, while decreasing the conditional entropy $H(Y|X)$ encourages cluster separation.
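The two entropy terms in (1) can be estimated from a batch of softmax outputs, as in the minimal sketch below. It assumes the usual empirical estimates, which the text does not spell out: $H(Y)$ is the entropy of the batch-averaged softmax output and $H(Y|X)$ is the average per-sample entropy. The regularizer is passed in as a precomputed number, and the function names are illustrative.

```python
import math


def entropy(p):
    """Shannon entropy (natural log) of a probability vector; 0 * log 0 = 0."""
    return -sum(q * math.log(q) for q in p if q > 0.0)


def marginal_entropy(outputs):
    """H(Y): entropy of the softmax output averaged over the batch."""
    m, k = len(outputs), len(outputs[0])
    marginal = [sum(y[c] for y in outputs) / m for c in range(k)]
    return entropy(marginal)


def conditional_entropy(outputs):
    """H(Y|X): average entropy of the per-sample softmax outputs."""
    return sum(entropy(y) for y in outputs) / len(outputs)


def rim_loss(outputs, lam, regularizer=0.0):
    """L_rim = R(theta) - lambda * (H(Y) - H(Y|X))."""
    return regularizer - lam * (marginal_entropy(outputs) - conditional_entropy(outputs))


confident_balanced = [[1.0, 0.0], [0.0, 1.0]]   # separated and balanced clusters
undecided = [[0.5, 0.5], [0.5, 0.5]]            # no separation at all
print(rim_loss(confident_balanced, lam=1.0))
print(rim_loss(undecided, lam=1.0))
```

The confident, balanced batch gets the lower loss, matching the interpretation that (1) rewards both cluster balancing and cluster separation.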

Reconstruction

In order to enhance the representativeness of centroids, we introduce a reconstruction loss (2) that penalizes high reconstruction errors from centroids to corresponding samples. Specifically, the decoder takes a centroid representation of a cluster and minimizes the difference between input images that belong to the cluster and the reconstructed image from the centroid representation.

$$L_{rec} = \frac{1}{m} \sum_{j=1}^K\sum\limits_{I_i\in X_j}||I_i - D(r_j)||^2, \tag{2}$$

where $D(.)$ is the decoder which reconstructs samples in the batch using their corresponding centroid representations and $m$ is the number of images in the batch.
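A small numerical sketch of (2) is given below. It assumes images and decoder outputs are flattened lists of equal length, that each image carries the index of its cluster, and that the decoder is any callable mapping a centroid representation to an image; the identity decoder in the usage lines is purely a placeholder, not the decoder network from the post.

```python
def centroid_reconstruction_loss(images, assignments, centroid_reps, decoder):
    """L_rec: mean over the batch of ||I_i - D(r_j)||^2, with j the cluster of I_i."""
    total = 0.0
    for image, cluster in zip(images, assignments):
        reconstruction = decoder(centroid_reps[cluster])
        total += sum((a - b) ** 2 for a, b in zip(image, reconstruction))
    return total / len(images)


def identity_decoder(rep):
    """Placeholder decoder: returns the centroid representation unchanged."""
    return rep


images = [[0.0, 0.0], [2.0, 2.0]]
loss = centroid_reconstruction_loss(
    images,
    assignments=[0, 0],          # both images fall into cluster 0
    centroid_reps=[[1.0, 1.0]],  # illustrative centroid of that cluster
    decoder=identity_decoder,
)
print(loss)
```

Iterating over images with their cluster index visits exactly the same terms as the double sum over clusters and their members in (2).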

Metric loss

Let $f_i \in \mathbb{R}^{128}$ and $\hat{f_i} \in \mathbb{R}^{128}$ be the image embeddings of $I_i$ and $\hat{I_i}$, respectively. The proposed metric loss (3) aims to minimize the distance between $f_i$ and $\hat{f_i}$ while pushing $f_i$ far away from negative clusters.

$$L_{m} = -\sum_i \log \left( l(I_i,\hat{I_i}) \right) - \sum_i\sum\limits_{j=1, j \neq q}^K \log \left(l(I_i,c_j) \right) \tag{3}$$

Final Loss

The network in Figure 1 is trained in an end-to-end manner with the following multi-task loss.

$$L = \alpha L_{m} + \beta L_{rim} + \gamma L_{rec}, \quad (4)$$

where $L_m$ is the center-based softmax loss (3) for deep metric learning, $L_{rim}$ is the clustering loss (1), and $L_{rec}$ is the reconstruction loss (2).

Experimental Results

Ablation Study

We denote our model with:

  • only clustering loss (1) as only $L_{rim}$.
  • both the clustering loss (1) and the metric loss (3) as Center-based Softmax (CBS).
  • Center-based Softmax with Reconstruction (CBSwR).

Tab-1

Table 1: The impact of each loss component on the performance on CUB200-2011 dataset and the comparison to the baseline (Source).

Tab-2

Table 2: The impact of each loss component on the performance on Car196 dataset and the comparison to the baseline (Source).

Tables 1 and 2 present the comparative results between methods. With only the clustering loss, the accuracy is significantly lower than that of the baseline SME. However, using the centroids from the clustering to calculate the metric loss (i.e., CBS) yields a performance boost over the baseline (i.e., SME). Furthermore, the reconstruction loss enhances the representativeness of centroids, as confirmed by the improvements of CBSwR over CBS on both datasets.

Tab-3

Table 3: The training time (seconds) of different methods on CUB200-2011 and Car196 datasets with 20 epochs. The models are trained on an NVIDIA GeForce GTX 1080-Ti GPU (Source).

Tab-4

Table 4: The impact of the number of clusters of the final model CBSwR on the performance on CUB200-2011 dataset (Source).

Table 3 presents the training time of different methods on the CUB200-2011 and Car196 datasets. Although the asymptotic complexity of CBSwR for training one batch is $\mathcal{O}(Km)$, the model also includes a decoder, which adds to the actual training time. It is worth noting that the decoder is only involved during training. During testing, our method has similar computational complexity to SME.

Table 4 presents the impact of the number of clusters $K$ in the clustering loss on the CUB200-2011 dataset with our proposed model CBSwR (recall that the number of clusters $K$ is limited by the batch size $m$). During training, the number of samples per cluster varies depending on the batch and the number of clusters. At $K=32$, which is our final setting, the number of samples per cluster varies from 2 to 11 on average. The retrieval performance differs only slightly across the different numbers of clusters. This confirms the robustness of the proposed method w.r.t. the number of clusters.

Comparison to the state of the art

Tab-5

Table 5: Clustering and Recall performance on the CUB200-2011 dataset (Source).

Tab-6

Table 6: Clustering and Recall performance on the Car196 dataset (Source).

Table 5 presents the comparative results on the CUB200-2011 dataset. In terms of clustering quality (NMI metric), the proposed method and the state-of-the-art UDML methods MOM and SME achieve comparable accuracy. However, in terms of retrieval accuracy R@K, our method outperforms other approaches. Our proposed method is also competitive with most of the supervised DML methods.

Table 6 presents comparative results on Car196 dataset. Compared to unsupervised methods, the proposed method outperforms other approaches in terms of retrieval accuracy at all ranks of K. Our method is comparable to other unsupervised methods in terms of clustering quality.

Qualitative Visualization

Fig-5

Figure 2: Barnes-Hut t-SNE visualization of our embedding on the CUB200-2011 dataset.

Figure 2 shows the t-SNE plots on our learned embedding features on CUB200-2011. We can see that our embedding produces reasonable results in grouping similar visual objects despite the significant variations in view-point, pose, and configuration.

Conclusion

We propose a new method that utilizes deep clustering for deep metric learning to address the two challenges in UDML, i.e., positive/negative mining and efficient training. The method is based on a novel loss that consists of a learnable clustering function, a reconstruction function, and a center-based metric loss function. Our experiments on CUB200-2011 and Car196 datasets show state-of-the-art performance on the retrieval task, compared to other unsupervised learning methods.

Open Source

🐱 Github: https://github.com/aioz-ai/BMVC20_CBSwR

Multiple interaction learning with question-type prior knowledge for constraining answer search space in visual question answering.

Different approaches have been proposed for Visual Question Answering (VQA). However, few works consider how different joint modality methods behave with respect to question-type prior knowledge extracted from data, which constrains the answer search space and gives a reliable cue for reasoning about answers to questions asked about input images. In this blog, we share a novel VQA model that utilizes the question-type prior information to improve VQA by leveraging multiple interactions between different joint modality methods, based on their behaviors in answering questions of different types. Experiments on two benchmark datasets, i.e., VQA 2.0 and TDIUC, indicate that the proposed method achieves the best performance among the most competitive approaches.

Introduction

Some works treat the question type as side information that gives a strong cue for reasoning about the answer. However, the relation between question types and answers in the training data has not yet been investigated. Fig. 1 shows the correlation between question types and some answers in the VQA 2.0 dataset. It suggests, for example, that a question regarding quantity should be answered by a number, not a color. This observation indicates that the prior information obtained from the correlations between question types and answers constrains the answer search space of the VQA model. This constraint helps the VQA model make its final prediction and thus improves the overall performance.
Fig1

Figure 1. The distribution of candidate answers in each question type in VQA 2.0.
Although different joint modality methods and attention mechanisms have been proposed, we hypothesize that each method captures different aspects of the input. That means different attention mechanisms may provide different answers for questions belonging to different question types.
Fig. 2 shows examples in which the attention models (SAN and BAN) attend to different regions of the input images when dealing with questions of different types. Unfortunately, most recent VQA systems are based on a single attention model, such as BAN2, SAN, MLP, MCB, or STL. Given this observation, it is necessary to develop a VQA system that leverages the power of different attention models to deal with questions of different types.
Fig1
Figure 2. Examples of attention maps of different attention mechanisms. BAN and SAN identify different visual areas when answering questions from different types.

Methodology

The proposed multiple interaction learning with question-type prior knowledge (MILQT) is illustrated in Fig. 3. Similar to most VQA systems, MILQT consists of a joint learning solution for input questions and images, followed by a multi-class classification over a set of predefined candidate answers. However, MILQT can leverage multiple joint modality methods under the guidance of question types to output better answers.
Fig3

Figure 3. The introduced MILQT for VQA.
As shown in Fig. 3, MILQT consists of two modules: question-type awareness $\mathcal{A}$ and multi-hypothesis interaction learning $\mathcal{M}$. The first module aims to learn the question-type representation, which is further used to enhance the joint visual-question embedding features and to constrain the answer search space through prior knowledge extracted from data. Based on the question-type information, the second module aims to identify the behaviors of multiple joint learning methods and then adjust their contributions to the final predictions.

Question representation. Given an input question, following the recent state-of-the-art BAN, we trim the question to a maximum of 12 words. Questions shorter than 12 words are zero-padded. Each word is then represented by a 600-D vector that is a concatenation of the 300-D GloVe word embedding and the augmenting embedding from training data. This step results in a sequence of word embeddings of size $12 \times 600$, denoted as $f_w$. To obtain the intent of the question, $f_w$ is passed through a Gated Recurrent Unit (GRU), which results in a 1024-D vector representation $f_q$ for the input question.
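The trimming and zero-padding step can be sketched as follows. This is a simplified illustration, not the MILQT code: the vocabulary, the random embedding table, and the choice to map unknown words to the padding id are all hypothetical, and the GRU is omitted.

```python
import numpy as np

MAX_LEN = 12    # maximum question length, from the text
EMB_DIM = 600   # 300-D GloVe + 300-D augmenting embedding, from the text

def encode_question(question, vocab, max_len=MAX_LEN):
    """Map a question to exactly `max_len` token ids; id 0 is reserved for padding.

    Unknown words are also mapped to 0 here, which is a simplification.
    """
    tokens = question.lower().replace("?", "").split()[:max_len]  # trim to 12 words
    ids = [vocab.get(tok, 0) for tok in tokens]
    return ids + [0] * (max_len - len(ids))                       # zero-pad

# Hypothetical vocabulary and randomly initialised embedding table.
vocab = {"what": 1, "color": 2, "is": 3, "the": 4, "cat": 5}
embedding = np.random.default_rng(0).normal(size=(len(vocab) + 1, EMB_DIM))
embedding[0] = 0.0   # the padding row is all zeros

ids = encode_question("What color is the cat?", vocab)
f_w = embedding[ids]  # word-embedding sequence of shape (12, 600)
```

In the full model, `f_w` would then be fed to a GRU to produce the 1024-D question vector described above.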

Image representation. We use bottom-up attention, i.e., an object detector with a Faster R-CNN backbone, to extract the image representation. The input image is first passed through the bottom-up network to obtain a $K \times 2048$ bounding box representation, which is denoted as $f_v$ in Fig. 3.

Multi-level multi-modal fusion. Previous works perform only one level of fusion between linguistic and visual features, which may limit the capacity of these models to learn a good joint semantic space. In our work, we introduce a multi-level multi-modal fusion that encourages the model to learn a better joint semantic space and takes the question-type representation obtained from the question-type classification component as one of its inputs.

  • First level multi-modal fusion: The first level fusion is similar to previous works. Given visual features $f_v$, question features $f_q$, and any joint modality mechanism,
    we combine the visual features with the question features and learn attention weights for the visual and/or linguistic features. Different attention mechanisms learn the joint semantic space in different ways. The output of the first level multi-modal fusion is denoted as $f_{att}$ in Fig. 3.
  • Second level multi-modal fusion: In order to enhance the joint semantic space, the output of the first level multi-modal fusion $f_{att}$ is combined with the question-type feature $f_{qt}$, which is the output of the last FC layer of the question-type classification component.
    We try two simple but effective operators, i.e., element-wise multiplication (EWM) or element-wise addition (EWA), to combine $f_{att}$ and $f_{qt}$. The output of the second level multi-modal fusion, which is denoted as $f_{att-qt}$ in Fig. 3, can be seen as an attention representation that is aware of the question-type information.

Given an attention mechanism, $f_{att-qt}$ is used as the input to a classifier that predicts an answer for the corresponding question. This is shown in the Answer prediction boxes in Fig. 3.
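The second-level fusion reduces to a single element-wise operation. The sketch below assumes both features have already been projected to the same dimensionality; the function name and the toy vectors are illustrative only.

```python
import numpy as np

def second_level_fusion(f_att, f_qt, op="EWM"):
    """Combine the attention feature with the question-type feature.

    op="EWM": element-wise multiplication; op="EWA": element-wise addition.
    Both inputs are assumed to share the same shape.
    """
    if f_att.shape != f_qt.shape:
        raise ValueError("f_att and f_qt must have the same shape")
    if op == "EWM":
        return f_att * f_qt
    if op == "EWA":
        return f_att + f_qt
    raise ValueError(f"unknown fusion operator: {op}")

f_att = np.array([1.0, 2.0, 3.0])   # toy first-level fusion output
f_qt = np.array([0.5, 0.0, 2.0])    # toy question-type feature

f_att_qt_mul = second_level_fusion(f_att, f_qt, "EWM")
f_att_qt_add = second_level_fusion(f_att, f_qt, "EWA")
```

With multiplication, the question-type feature acts like a gate that can suppress dimensions of the attention feature; with addition, it acts as a bias.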

Multi-hypothesis interaction learning. As presented in Fig. 3, MILQT can utilize multiple hypotheses (i.e., joint modality mechanisms). Specifically, we propose a multi-hypothesis interaction learning design $\mathcal{M}$ that takes the answer predictions produced by different joint modality mechanisms and interactively learns to combine them.
Let $g \in \mathbb{R}^{A \times J}$ be the matrix of predicted probability distributions over $A$ answers from the $J$ joint modality mechanisms. $\mathcal{M}$ outputs the distribution $\rho \in \mathbb{R}^{A}$, which is calculated from $g$ through the equation below:

$$\rho = \mathcal{M}\left(g, w_{mil}\right) = \sum_{j} \left( m^T_{qt-ans} w_{mil} \odot g \right)$$

where $w_{mil} \in \mathbb{R}^{P \times J}$ is the learnable weight that controls the contributions of the $J$ considered joint modality mechanisms to the predicted answer, guided by the $P$ question types; $\odot$ denotes the Hadamard product.
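The shapes in this equation are easier to follow in code. The sketch below is an assumption-laden illustration: $m_{qt-ans}$ is not defined in this post, so it is taken here to be a $P \times A$ matrix relating question types to answers, which is the shape the equation needs in order to type-check. All values are toy numbers.

```python
import numpy as np

def multi_hypothesis_interaction(g, w_mil, m_qt_ans):
    """g: (A, J) answer predictions from J mechanisms;
    w_mil: (P, J) learnable weights; m_qt_ans: (P, A), assumed question-type/answer prior.
    Returns rho of shape (A,)."""
    weights = m_qt_ans.T @ w_mil           # (A, J): per-answer weight for each mechanism
    return np.sum(weights * g, axis=1)     # Hadamard product, then sum over mechanisms j

# Toy setup: A = 3 answers, J = 2 mechanisms, P = 2 question types.
m_qt_ans = np.array([[1.0, 1.0, 0.0],      # type 0 is associated with answers 0 and 1
                     [0.0, 0.0, 1.0]])     # type 1 is associated with answer 2
w_mil = np.array([[0.8, 0.2],              # type 0 trusts mechanism 0 more
                  [0.3, 0.7]])             # type 1 trusts mechanism 1 more
g = np.array([[0.5, 0.1],
              [0.3, 0.2],
              [0.2, 0.7]])

rho = multi_hypothesis_interaction(g, w_mil, m_qt_ans)
```

In this toy case, answer 2 ends up with the highest score because the mechanism that favors it is also the one most trusted for its question type.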

Results

Experiments on VQA 2.0 test-dev and test-standard.

We evaluate MILQT on the test-dev and test-standard splits of the VQA 2.0 dataset. To train the model, similar to previous works, we use both the training set and the validation set of VQA 2.0. We also use Visual Genome as additional training data. MILQT consists of three joint modality mechanisms, i.e., BAN-2, BAN-2-Counter, and SAN, accompanied by EWM for the multi-modal fusion, and the predicted question type together with the prior information to augment the VQA loss. Table 1 presents the results of different methods on test-dev and test-std of VQA 2.0. The results show that MILQT performs well against the most competitive approaches.

Tab-2

Table 1. Comparison to the state of the art on the test-dev and test-standard of VQA 2.0. For fair comparison, GloVe embedding and GRU are leveraged for question embedding and bottom-up features are used to extract visual information. CMP, i.e., Cross-Modality with Pooling, is the LXMERT with the aforementioned setup (Source).

Experiments on TDIUC.

To demonstrate the stability of MILQT, we also evaluate it on the TDIUC dataset.
The results in Table 2 show that the proposed model establishes state-of-the-art results on both evaluation metrics, Arithmetic MPT and Harmonic MPT. Specifically, our model significantly outperforms the recent QTA: overall, we improve over QTA by 6.1% and 11.1% on the Arithmetic MPT and Harmonic MPT metrics, respectively. It is worth noting that the QTA results in Table 2, which are cited from QTA, are achieved when QTA uses the one-hot predicted question type of the test question to weight visual features. When using the ground-truth question type to weight visual features, QTA reported 69.11% and 60.08% for the Arithmetic MPT and Harmonic MPT metrics, respectively. Our model also outperforms these results by a large margin, i.e., the improvements are 3.9% and 6.8% for the Arithmetic MPT and Harmonic MPT metrics, respectively.

Tab-2

Table 2. The comparative results between the proposed model and other models on the validation set of TDIUC (Source).
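For readers unfamiliar with the two metrics, the sketch below assumes that MPT stands for mean-per-type accuracy, so that Arithmetic MPT and Harmonic MPT are the arithmetic and harmonic means of the per-question-type accuracies. The accuracies used are made-up numbers, not results from Table 2.

```python
from statistics import harmonic_mean, mean

def arithmetic_mpt(per_type_accuracy):
    """Arithmetic mean of the per-question-type accuracies."""
    return mean(list(per_type_accuracy.values()))

def harmonic_mpt(per_type_accuracy):
    """Harmonic mean of the per-question-type accuracies;
    it is pulled down sharply by any question type the model handles poorly."""
    return harmonic_mean(list(per_type_accuracy.values()))

# Hypothetical per-type accuracies in percent.
accuracy = {"counting": 50.0, "color": 100.0}

a_mpt = arithmetic_mpt(accuracy)
h_mpt = harmonic_mpt(accuracy)
```

Because the harmonic mean penalizes weak question types more heavily, a gain on Harmonic MPT suggests the model improves on its hardest types rather than only on the easy ones.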

Conclusion

We present multiple interaction learning with question-type prior knowledge for constraining the answer search space (MILQT), which takes the question-type information into account to improve VQA performance at different stages. The system can also utilize and learn different attention mechanisms under a unified model in an interacting manner. The extensive experimental results show that all proposed components improve VQA performance. MILQT achieves the best performance among the most competitive approaches on the VQA 2.0 and TDIUC datasets.

Open Source

Github: https://github.com/aioz-ai/ECCVW20_MILQT

A Brief Introduction to Visual Question Answering

1. Visual Question Answering - Overview

Visual Question Answering (VQA) aims to find a correct answer to a given question that is consistent with the visual content of a given image. The overarching goal of this task is to create systems that can comprehend the contents of an image in the same way that humans do and communicate effectively about that image in natural language. It is a challenging task, as it requires an image feature extractor and a natural language processor to interact and complement each other.

There are two main variants of VQA: Free-Form Open-Ended (FFOE) VQA and Multiple Choice (MC) VQA. In FFOE VQA, an answer is a free-form response to a given image-question pair, while in MC VQA, an answer is chosen from an answer list for a given image-question pair. The VQA variants will be discussed in the next post.

2. Approaches for solving VQA task.

There are three main approaches for VQA:

  • Compositional VQA models: the questions are interpreted as a set of many sub-tasks.
  • Bayesian and Question-Aware models: this approach is not well suited to systems that answer image-related questions, since the algorithm does not look at the picture and instead predicts the response from a Bayesian model by estimating the probability of the words in the dataset's responses.
  • Attention based models: this approach tries to learn the interaction between image and question features through a module called attention. The joint features obtained from that module are then leveraged to answer the corresponding question.

The last one is the most successful approach, since recent state-of-the-art methods include attention mechanisms.

3. Attention based VQA approach.

In general, attention based VQA approaches have four main steps (See Figure 1):

  • Visual Representation: Encode the information from the image into vector(s) using a Convolutional Neural Network (CNN).
  • Textual Representation: Encode the information of the question into vector(s) using an embedding.
  • Joint Representation: A further step to learn the interaction between question(s) and image(s). The output joint features can be vector(s).
  • Answer prediction: the joint features from the previous step are passed through this module to obtain the predicted answer. This module is typically a classifier.

Overall

Figure 1. The general approach for Visual Question Answering.

3.1. Visual Representation

Features are the basic attributes or aspects that help us recognize a specific object or image; they are the distinct properties that set it apart. When operating on a VQA dataset, we must extract the features of various images in order to separate the images based on specific features or aspects. Image features are one of the most important pieces of information for a VQA system to output the correct answer.

Convolutional neural networks have emerged as the gold standard for image pattern recognition. An input image is converted into image features after it is passed through a convolutional network. Each filter in a CNN layer detects various patterns, such as corners, vertices, shapes, curves, and symmetries (See Figure 2).

Img Extraction

Figure 2. An example of feature extraction in VQA classification.

The majority of the VQA literature employs CNNs for image processing. The network's final layer is removed, and the remaining network is used to extract image features. For image representation in VQA, objects in images are often represented by features extracted from an object detector such as the Faster R-CNN bottom-up model.

3.2. Textual Representation

Textual embeddings can be obtained in a variety of ways. Count-based and frequency-based techniques such as count vectorization and TF-IDF are examples of older approaches. There are also prediction-based approaches such as continuous bag-of-words and skip-gram. Pretrained Word2Vec models are also openly accessible. Embeddings can also be created using deep learning architectures such as RNNs, LSTMs, GRUs, and 1-D CNNs. LSTMs are among the most often used in the VQA literature. For question embedding in VQA, GloVe or BERT is widely used to capture the representation of words and sentences in different contexts (See Figure 3 for a sample structure of question embedding).

QEmb

Figure 3. An example of question embedding for VQA.

3.3. Joint Representation

In current VQA systems, the joint modality component plays an essential role, since it learns meaningful joint representations between linguistic and visual inputs by applying the attention mechanism. Many works learn the interaction between question and image. One example is CTI (Do et al. 2018), a novel trilinear interaction model that simultaneously learns high level associations between image, question, and answer information. See Figure 4 for more details.

CTI

Figure 4. Compact Trilinear Interaction mechanism for VQA (Source).

3.4. Answer Prediction

In most recent works, the joint features obtained from the attention mechanism are passed through a classifier to output the predicted answer. However, more modules can also be applied to bring in external knowledge and deal with difficult questions.

Overcoming Data Limitation in Medical Visual Question Answering

What are the difficulties when dealing with Medical VQA task?

Visual Question Answering (VQA) aims to provide a correct answer to a given question such that the answer is consistent with the visual content of a given image.

In the medical domain, VQA could benefit both doctors and patients. For example, doctors could use answers provided by a VQA system as support materials in decision making, while patients could ask a VQA system questions related to their medical images to better understand their health.

Fig-1

Figure 1: An example of Medical VQA (Source).

However, one major problem with medical VQA is the lack of large scale labeled training data which usually requires huge efforts to build.

  • The first attempt to build a dataset for medical VQA was by ImageCLEF-Med. In this dataset, images were automatically captured from PubMed Central articles, and the questions and answers were automatically generated from the corresponding image captions. As a result of this construction, the data has a high noise level, i.e., the dataset includes many images that are not useful for direct patient care, and it also contains questions that do not make any sense.
  • Recently, the first manually constructed dataset for the medical VQA task, VQA-RAD, was released. Unfortunately, it contains only 315 images, which prevents the direct application of powerful deep learning models to the VQA problem. One may consider transfer learning, in which deep learning models pretrained on a large scale labeled dataset such as ImageNet are finetuned on the medical VQA data. However, due to the difference in visual concepts between ImageNet images and medical images, finetuning with very few medical images is not sufficient.

Therefore, it is necessary to develop a new VQA framework that can improve the accuracy while needing only a small amount of labeled training data.

The motivation for our approach to overcome the data limitation of medical VQA comes from two observations:

  • Firstly, we observe that large scale unlabeled medical image collections are available. These images are from the same domain as the medical VQA images. Hence, if we train an unsupervised deep learning model using these unlabeled images, the trained weights may be easier to adapt to the medical VQA problem than weights pretrained on ImageNet images.
  • Another observation is that although the labeled dataset VQA-RAD is primarily designed for VQA, with a little effort we can extract new class labels for that dataset. The new class labels allow us to apply the recent meta-learning technique for learning meta-weights that can be quickly adapted to the VQA problem later.

Methodology

The proposed medical VQA framework is presented in Figure 2. In our framework, the image feature extraction component is initialized by pretrained weights from MAML and CDAE. After that, the VQA framework will be finetuned in an end-to-end manner on the medical VQA data. In the following sections, we detail the architectures of MAML, CDAE, and our framework.

Fig-2

Figure 2: The proposed medical VQA. The image feature extraction is denoted as 'Mixture of Enhanced Visual Features (MEVF)' and is marked with the red dashed box. The weights of MEVF are initialized by MAML and CDAE (Source).

Model-Agnostic Meta-Learning -- MAML

The MAML model consists of four $3\times3$ convolutional layers with stride $2$ and ends with a mean pooling layer; each convolutional layer has $64$ filters and is followed by a ReLU layer.

We create the dataset for training MAML by manually reviewing around three thousand question-answer pairs from the training set of the VQA-RAD dataset. In our annotation process, images are split into three parts based on their body part labels (head, chest, abdomen). Images from each body part are further divided into three subcategories based on the interpretation of the question-answer pairs corresponding to the images. These subcategories are: (1) normal images, in which no pathology is found; (2) abnormal present images, in which fluid, air, a mass, or a tumor is present; (3) abnormal organ images, in which an organ is enlarged or in the wrong position.

Thus, all the images are categorized into 9 classes:

| head normal | head abnormal present | head abnormal organ |
| chest normal | chest abnormal organ | chest abnormal present |
| abdominal normal | abdominal abnormal organ | abdominal abnormal present |

In every iteration of MAML training (line 3 in Alg. 1), 5 tasks are sampled. For each task, we randomly select 3 classes (from the 9 classes). For each class, we randomly select 6 images, of which 3 are used for updating the task models and the remaining 3 are used for updating the meta-model.

Alg-1
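The task sampling described above can be sketched in a few lines. This is an illustrative sketch rather than the released training code: the function name, the "support"/"query" naming for the two image groups, and the image ids are assumptions made for the example.

```python
import random

BODY_PARTS = ("head", "chest", "abdominal")
KINDS = ("normal", "abnormal present", "abnormal organ")
CLASSES = [f"{part} {kind}" for part in BODY_PARTS for kind in KINDS]  # the 9 classes

def sample_tasks(images_by_class, n_tasks=5, n_way=3, n_support=3, n_query=3, rng=random):
    """Sample `n_tasks` tasks; each task has `n_way` classes and 6 images per class."""
    tasks = []
    for _ in range(n_tasks):
        task = {"support": [], "query": []}
        for cls in rng.sample(sorted(images_by_class), n_way):
            picked = rng.sample(images_by_class[cls], n_support + n_query)
            # First 3 images update the task model, the other 3 update the meta-model.
            task["support"] += [(img, cls) for img in picked[:n_support]]
            task["query"] += [(img, cls) for img in picked[n_support:]]
        tasks.append(task)
    return tasks

# Hypothetical image ids, 10 per class.
images_by_class = {cls: [f"{cls}/{i}.png" for i in range(10)] for cls in CLASSES}
tasks = sample_tasks(images_by_class, rng=random.Random(0))
```

Each sampled task therefore contains 9 images for the task-level update and 9 different images for the meta-level update.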

Denoising Auto Encoder -- CDAE

The encoder maps an image $x'$, which is the noisy version of the original image $x$, to a latent representation $z$ which retains a useful amount of information. The decoder transforms $z$ to the output $y$. The training algorithm aims to minimize the reconstruction error between $y$ and the original image $x$ as follows

$$L_{rec} = \left\| x - y \right\|_2^2$$

In our design, the encoder is a stack of convolutional layers; each of them is followed by a max pooling layer. The decoder is a stack of deconvolutional and convolutional layers. The noisy version $x'$ is obtained by adding Gaussian noise to the original image $x$.
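The corruption step and the reconstruction error can be illustrated without any network. In the sketch below, the noise level `sigma` and the image size are hypothetical, and the noisy image stands in for an untrained decoder output.

```python
import numpy as np

def corrupt(x, sigma=0.1, rng=None):
    """Return x', the input image with additive Gaussian noise (sigma is a hypothetical value)."""
    if rng is None:
        rng = np.random.default_rng()
    return x + rng.normal(0.0, sigma, size=x.shape)

def cdae_reconstruction_loss(x, y):
    """Squared L2 distance between the clean image x and the decoder output y."""
    return float(np.sum((x - y) ** 2))

rng = np.random.default_rng(0)
x = rng.random((8, 8))                    # toy "clean" image
x_noisy = corrupt(x, sigma=0.1, rng=rng)  # what the encoder would receive

perfect = cdae_reconstruction_loss(x, x)         # a decoder that recovers x exactly
no_denoise = cdae_reconstruction_loss(x, x_noisy)  # a decoder that merely copies its input
```

The loss is always measured against the clean image, so a network that simply copies its noisy input is penalized; this is what pushes the latent representation to capture structure rather than noise.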

To train CDAE, we collect 11,779 unlabeled images available online, which are brain MRI images, chest X-ray images, and CT abdominal images. The dataset is split into a train set with 9,423 images and a test set with 2,356 images. We use Gaussian noise to corrupt the input images before feeding them to the encoder.

Our VQA framework

After training MAML and CDAE, we use their trained weights to initialize the MEVF image feature extraction component in the VQA framework. We then finetune the whole VQA model using the training set of VQA-RAD dataset.

To train the proposed model, we introduce a multi-task loss function to incorporate the effectiveness of the CDAE into VQA. Formally, our loss function is defined as follows:

$$L = \alpha_1 L_{vqa} + \alpha_2 L_{rec}$$

where $L_{vqa}$ is a cross-entropy loss for VQA classification and $L_{rec}$ stands for the reconstruction loss of CDAE. The whole VQA model is finetuned in an end-to-end manner.

Results

Tab-1

Table 1: VQA results on VQA-RAD test set. All reference methods differ at the image feature extraction component. Other components are similar. The Stacked Attention Network (SAN) is used as the attention mechanism in all methods (Source).

Table 1 presents VQA accuracy on both open-ended and close-ended VQA-RAD questions on the test set. The results show that for both MAML and CDAE, pretraining followed by finetuning significantly improves the performance over training from scratch using only VQA-RAD.

In addition, the results show that our pretraining and finetuning of MAML and CDAE give better performance than finetuning VGG-16 pretrained on the ImageNet dataset. Our proposed image feature extraction MEVF, which leverages the pretrained weights of both MAML and CDAE and then finetunes them, gives the best performance. This confirms the effectiveness of the proposed MEVF for dealing with the limitation of labeled training data for medical VQA.

Tab-2

Table 2: Performance comparison on VQA-RAD test set (Source).

Table 2 presents comparative results between methods. Note that for the image feature extraction, the baselines use pretrained models (VGG or ResNet) that have been trained on ImageNet and then finetuned on the VQA-RAD dataset. For the question feature extraction, all baselines and our framework use the same pretrained model (i.e., GloVe), finetuned on VQA-RAD. The results show that when BAN or SAN is used as the attention mechanism in our framework, it significantly outperforms the baseline frameworks BAN and SAN. Our best setting, i.e., the one with BAN as the attention mechanism, achieves state-of-the-art results and significantly outperforms the best baseline framework BAN, i.e., the improvements are 16.3% and 8.6% on open-ended and close-ended VQA, respectively.

Conclusion

In this paper, we proposed a novel medical VQA framework that leverages the meta-learning MAML and denoising auto-encoder CDAE for image feature extraction in order to overcome the limitation of labeled training data. Specifically, CDAE helps to leverage information from the large scale unlabeled images, while MAML helps to learn meta-weights that can be quickly adapted to the VQA problem. We establish new state-of-the-art results on VQA-RAD dataset for both close-ended and open-ended questions.

Open Source

🐱 Github: https://github.com/aioz-ai/MICCAI19-MedVQA

Data Augmentation for Colon Polyp Detection: A systematic Study

Colorectal cancer (CRC)♋, also known as bowel cancer or colon cancer, is a cancer that develops from growths in the colon or rectum called polyps. Detecting polyps is a common approach in screening colonoscopies to prevent CRC at an early stage. Early colon polyp detection from medical images is still an unsolved problem due to the considerable variation of polyps in shape, texture, size, color, and illumination, and the lack of publicly annotated datasets. At AIOZ, we adopt a recently proposed auto-augmentation method for polyp detection. We also conduct a systematic study on the performance of different data augmentation methods for colon polyp detection. The experimental results show that auto-augmentation achieves the best performance compared to other augmentation strategies.

Introduction

Colorectal cancer (CRC) is the third-largest cause of cancer deaths worldwide in men and the second in women, with up to 700,000 deaths each year [1]. Detection and removal of colon polyps at an early stage reduces the mortality from CRC. There are several methods for colon screening, such as CT colonography or wireless capsule endoscopy, but the gold standard is colonoscopy [2].

A colonoscopy is performed by an experienced doctor who uses a colonoscope to screen for abnormalities such as intestinal signs, symptoms, colon cancer, and polyps. Abnormal polyps can be removed, and small amounts of tissue can be taken for analysis during the colonoscopy. However, the most crucial drawback of colonoscopy is the polyp miss rate, especially for polyps smaller than 10 mm. Several factors cause the miss rate. They include both subjective factors, such as bowel preparation, the specific choice of endoscope, video processor, and clinician skill, and objective factors, such as polyp appearance and camera movement. For these reasons, automatic polyp detection is a promising approach to assist clinicians in improving the sensitivity of the diagnosis.

Previous research shows that automatic polyp detection using deep learning-based methods outperforms hand-crafted methods, as demonstrated by the top two results in the MICCAI 2015 challenge [3]. For deep learning-based approaches, beyond the model architecture, data augmentation is also a critical factor in making significant improvements, given the lack of annotated data. Recent work [4] shows that learning an optimal augmentation policy from data, instead of hand-crafting data augmentation strategies, helps detectors generalize better. Thus, studying auto augmentation for polyp detection is necessary. In this research, we adapt Faster R-CNN [5] together with AutoAugment [6] to detect polyps in colonoscopy video frames. We also evaluate traditional data augmentation [7] to compare the effectiveness of different augmentation strategies.

Methodology

1. Polyp Detector

Thanks to the power of deep learning, recent works [5, 11, 12] show that deep learning-based detection methods give impressive detection performance. In this work, we use the Faster R-CNN object detector [5] with a ResNet101 [13] backbone pre-trained on the COCO dataset. Our experiments show that this architecture gives competitive performance on the polyp detection problem. The experimental setting for the detector is as follows. The network is trained using stochastic gradient descent (SGD) with 0.9 momentum; the learning rate is initially set to 3e-4 and is decreased to 3e-5 from iteration 900k. The number of anchor boxes per location used in our model is 12 (4 scales, i.e., 64×64, 128×128, 256×256, 512×512, and 3 aspect ratios, i.e., 1:2, 1:1, 2:1).
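The count of 12 anchors per location follows from pairing every scale with every aspect ratio. The sketch below enumerates them under one common convention, which is an assumption here: each anchor keeps the area of its scale, and the ratio is read as width:height.

```python
import math
from itertools import product

SCALES = (64, 128, 256, 512)       # anchor side lengths for a square anchor, from the text
ASPECT_RATIOS = (0.5, 1.0, 2.0)    # 1:2, 1:1, 2:1, read here as width:height (an assumption)

def make_anchors(scales=SCALES, ratios=ASPECT_RATIOS):
    """Return (width, height) for every scale/ratio pair, keeping the area at scale**2."""
    anchors = []
    for scale, ratio in product(scales, ratios):
        area = scale * scale
        width = math.sqrt(area * ratio)   # from width / height = ratio and width * height = area
        height = area / width
        anchors.append((width, height))
    return anchors

anchors = make_anchors()
```

For example, the 64-pixel scale yields one square 64×64 anchor plus one wide and one tall anchor of the same area.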

2. Data Augmentation

autoaugment_figure_small

Fig. 1. Example of applying learned augmentation policies to a colonoscopy image.

Data augmentation can be split into two types: self-defined data augmentation (a.k.a traditional augmentation) and auto augmentation [6]. In this study, we adopt an automated data augmentation approach for object detection, i.e., Auto-augment [6], which finds optimal data augmentation policies during training. In Auto-augment, an augmentation policy consists of several sub-policies; each sub-policy consists of two operations. Each operation is an image transformation containing two parameters: probability and the magnitude of the shift. There are three types of transformations used in Auto-augment for object detection [4], which are

  • Color operations: distort color channels without impacting the locations of the bounding boxes
  • Geometric operations: geometrically distort the image, which correspondingly alters the location and size of the bounding box annotations
  • Bounding box operations: only distort the pixel content contained within the bounding box annotations.

One of the essential conclusions in [4] is that the learned policy found on COCO can be directly applied to other detection datasets and models to improve predictive accuracy. Hence, in this study, we apply the learned policies from [4] to augment data when training the detector described in the Polyp Detector section. The learned policy we use for training our detector is summarized in Table 1. In Table 1, each operation is a triple that describes the transformation, its probability, and its magnitude. Due to space limitations, we refer the reader to [4] for detailed descriptions of the transformations. Fig. 1 shows augmented examples obtained by applying the learned augmentation policy to a polyp image from the training dataset.
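The policy structure described above (a policy made of sub-policies, each holding two operations with a probability and a magnitude) can be sketched as plain data plus a sampling routine. The operation names and numbers below are placeholders rather than the learned values from Table 1 or [4], and the transforms only record what would be applied instead of editing pixels.

```python
import random

# Hypothetical policy: names, probabilities, and magnitudes are placeholders,
# not the learned values reported in Table 1 or in [4].
POLICY = [
    [("color_op", 0.6, 10), ("geometric_op", 0.8, 4)],
    [("bbox_only_op", 0.3, 2), ("color_op", 0.9, 6)],
]


def apply_sub_policy(sample, sub_policy, transforms, rng):
    """Apply each (name, probability, magnitude) operation in order."""
    for name, probability, magnitude in sub_policy:
        if rng.random() < probability:
            sample = transforms[name](sample, magnitude)
    return sample


def augment(sample, policy, transforms, rng):
    """Pick one sub-policy uniformly at random and apply it to the sample."""
    sub_policy = rng.choice(policy)
    return apply_sub_policy(sample, sub_policy, transforms, rng)


def make_logging_transform(name):
    """Stand-in transform that appends (name, magnitude) to a list."""
    def transform(sample, magnitude):
        return sample + [(name, magnitude)]
    return transform


TRANSFORMS = {
    name: make_logging_transform(name)
    for name in ("color_op", "geometric_op", "bbox_only_op")
}
```

In a real pipeline, a geometric transform would also have to update the bounding box coordinates, as noted in the list above.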

tab-1

Table 1. Sub-policies and operations used in our experiment.

In addition to auto augmentation, we also investigate the effect of traditional augmentation and the combination of traditional and automatic augmentation. For traditional data augmentation, we randomly apply several transformations to the image, such as rotation, mirroring, shearing, translation, and zoom. We propose three strategies to combine these data augmentation types: (1) the detector is first trained with AutoAugment and then with traditional data augmentation; (2) the detector is first trained with traditional augmentation and then with auto augmentation; (3) the detector is trained with AutoAugment on both the original data and the data generated by traditional augmentation. All augmentation strategies are evaluated with the same model architecture and training configuration. This allows us to explore which data augmentation method is suitable for the polyp detection problem.

Experiments

We use CVC-ClinicDB [14] for training and ETIS-Larib [15] for testing. This allows us to make a fair comparison with the MICCAI 2015 challenge results, which are reported on the same dataset. The CVC-ClinicDB dataset contains 612 polyp image frames of 31 unique polyps from 31 different colonoscopy videos. The ETIS-Larib dataset contains 196 high-resolution image frames of 44 different polyps.

fp_figure

Fig. 2. Examples of false-positive detections on the testing dataset. Green boxes and blue boxes are ground truths and predictions, respectively.

fn_figure

Fig. 3. Examples of false-negative detections on the testing dataset. Green boxes and blue boxes are ground truths and predictions, respectively.

Fig. 2 and Fig. 3 visualize several failure cases of our model on the testing dataset, in which the blue boxes are the predicted locations and the green boxes are ground truths. The false-positive samples (Fig. 2) are caused by shortcomings in bowel preparation (i.e., leftovers of food and fluid in the colon), while the false-negative samples (Fig. 3) are caused by variations in polyp type and appearance (i.e., small polyps, flat polyps, and similarities between polyps and colon veins).

tab-3

Table 2. Comparison among traditional data augmentation (TDA), auto augmentation (AA), and their combinations.

Table 2 shows the comparative results between different augmentation strategies. The results show that the third combination method (AA-TDA-3) achieves higher Precision than AutoAugment, i.e., 75.90% versus 74.51%. However, overall, AutoAugment (AA) achieves the best results because it best addresses the polyp miss rate (i.e., 152) with an acceptable number of false positives (i.e., 52). The competitive performance of auto augmentation (AA) confirms that the data augmentation policies learned on the COCO dataset [4] are transferable.
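As a quick consistency check on the AA numbers quoted above, precision can be recomputed from detection counts. This sketch assumes that 152 is the number of correctly detected polyps (true positives) and 52 is the number of false positives; the text gives the figures but does not label them this explicitly.

```python
def precision(true_positives, false_positives):
    """Precision = TP / (TP + FP), returned as a percentage."""
    detected = true_positives + false_positives
    if detected == 0:
        return 0.0
    return 100.0 * true_positives / detected


# Assumption: 152 true positives and 52 false positives for AutoAugment (AA).
aa_precision = precision(152, 52)
print(f"AA precision: {aa_precision:.2f}%")
```

Under that assumption, the result rounds to the 74.51% Precision reported for AA.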

tab-4

Table 3. Comparative results between our model and the state of the art.

Table 3 presents the comparative results between auto augmentation in our model and other state-of-the-art results. Among the compared methods, CUMED, OUR, and UNS-UCLAN are end-to-end deep learning-based approaches. The results show that, compared to methods from the MICCAI challenge [3], auto augmentation achieves better performance on all metrics. Compared to the recent method [10], auto augmentation also achieves better performance on all metrics except FP. These results confirm the effectiveness of auto augmentation for polyp detection.

Conclusion

This study adopts a deep learning-based object detection method with auto data augmentation for polyp detection. Different augmentation strategies are evaluated. The experimental results show that auto augmentation policies learned from a general object detection dataset transfer well to the polyp detection problem. Although auto augmentation achieves competitive results, it still has a high FP count compared to the state of the art. This weakness can be reduced by post-processing steps such as false-positive learning.

Open Source

🍅 Github: https://github.com/aioz-ai/polyp-detection

🍓 Blog post: https://ai.aioz.io/blog/polyp-detection

Acknowledgements

This research was conducted by Phong Nguyen, Quang Tran, Erman Tjiputra, and Toan Do. We'd like to give special thanks to the other AIOZ AI team members for their support and feedback.

🎉 All the above contributions were incredibly enabling for this research. 🎉

Reference

[1] Hamidreza Sadeghi Gandomani, Mohammad Aghajani, et al., “Colorectal cancer in the world: incidence, mortality and risk factors,” Biomedical Research and Therapy, 2017.

[2] Florence Bénard, Alan N Barkun, et al., “Systematic review of colorectal cancer screening guidelines for average-risk adults: Summarizing the current global recommendations,” World Journal of Gastroenterology, 2018.

[3] Jorge Bernal, Nima Tajbakhsh, et al., “Comparative validation of polyp detection methods in video colonoscopy: results from the MICCAI 2015 endoscopic vision challenge,” IEEE Transactions on Medical Imaging, pp. 1231–1249, 2017.

[4] Barret Zoph, Ekin D Cubuk, et al., “Learning data augmentation strategies for object detection,” arXiv, 2019.

[5] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun, “Faster R-CNN: Towards real-time object detection with region proposal networks,” in NIPS, 2015.

[6] Ekin D Cubuk, Barret Zoph, et al., “AutoAugment: Learning augmentation strategies from data,” in CVPR, 2019.

[7] Younghak Shin, Hemin Ali Qadir, et al., “Automatic colon polyp detection using region based deep CNN and post learning approaches,” IEEE Access, 2018.

[8] Yangqing Jia, Evan Shelhamer, et al., “Caffe: Convolutional architecture for fast feature embedding,” in ACMMM, 2014.

Introduction to Federated Learning

Federated Learning: machine learning over a distributed dataset, where user devices (e.g., desktop, mobile phones, etc.) are utilized to collaboratively learn a shared prediction model while keeping all training data locally on the device. This approach decouples the ability to do machine learning from storing the data in the cloud.

Conceptually, federated learning proposes a mechanism to train a high-quality centralized model while the training data remains distributed over many clients, each with unreliable and relatively slow network connections.

The idea behind federated learning is as conceptually simple as it is technologically complex. Traditional machine learning programs rely on a centralized model for training, in which a group of servers runs a specific model against training and validation datasets. That centralized training approach can work very efficiently in many scenarios. Still, it has proven challenging in use cases involving a large number of endpoints using and improving the model. The prototypical example of the limitation of the centralized training model can be found in mobile or internet of things (IoT) scenarios, where the quality of a model depends on the information processed across hundreds of thousands or millions of devices. In those scenarios, each endpoint can contribute to a machine learning model's training in its own autonomous way. In other words, knowledge is federated.
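To illustrate the idea of training a shared model without moving the data, here is a minimal sketch of one well-known scheme, federated averaging, which the text above does not name explicitly. Each hypothetical client fits a one-parameter model on its private data, and the server only sees and averages the resulting weights. The data, learning rate, and step counts are illustrative assumptions.

```python
def local_update(weight, data, lr=0.1, steps=5):
    """Run a few gradient steps of least-squares fitting y ~ weight * x on one client."""
    for _ in range(steps):
        grad = sum(2 * (weight * x - y) * x for x, y in data) / len(data)
        weight -= lr * grad
    return weight


def federated_round(global_weight, clients):
    """One round: each client trains locally, the server averages by sample count."""
    total = sum(len(data) for data in clients)
    updates = [(local_update(global_weight, data), len(data)) for data in clients]
    return sum(w * n for w, n in updates) / total


# Hypothetical clients whose private data all follow y = 3x.
# Only the locally trained weights are shared, never the raw (x, y) pairs.
CLIENTS = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(0.5, 1.5)],
    [(1.5, 4.5), (1.0, 3.0), (0.2, 0.6)],
]


def train(rounds=20):
    """Repeat federated rounds starting from a zero-initialized shared weight."""
    weight = 0.0
    for _ in range(rounds):
        weight = federated_round(weight, CLIENTS)
    return weight
```

Weighting each client's update by its sample count keeps clients with more data from being under-represented in the shared model.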

Blockchain: large, distributed dataset, where no-one can edit/delete an old entry, nor fake a new entry. The data is enforced by fundamental limits of computations (i.e., Proof of Work).

Smart Contract: dataset stored on the Blockchain, which includes: Data (i.e., ledgers, events, statistics), State (today's ledger, today's events), Code (rules for changing state).

Math Examples

Fundamental Theorem of Calculus. Let $f:[a,b] \to \mathbb{R}$ be Riemann integrable. Let $F:[a,b] \to \mathbb{R}$ be $F(x) = \int_{a}^{x} f(t)\,dt$. Then $F$ is continuous, and at all $x$ such that $f$ is continuous at $x$, $F$ is differentiable at $x$ with $F'(x) = f(x)$.

Lift ($L$) can be determined from the lift coefficient ($C_L$) using the following equation.

$$L = \frac{1}{2} \rho v^2 S C_L$$

References

[1] J.Rodriguez. Whats New in Deep Learning Research: Understanding Federated Learning.

Graph-based Person Signature for Person Re-Identifications

Motivation

Person re-identification (ReID) aims to retrieve a particular person image in a collection of images captured by multiple cameras from various viewpoints across time.

The challenges of the person ReID task come from significant variations of human attributes such as poses, gaits, clothes, as well as challenging environmental settings like illumination, complex background, and occlusions. With the rise of deep learning, most of the recent studies utilize Convolutional Neural Network (CNN) to tackle the person ReID problem.

Recently, attribute-based methods have shown great success in providing semantic features for the deep network. Unlike the person identity label, which offers only coarse information to identify one identity among all other person identities, the attributes are the detailed descriptions that are highly intuitive and mostly unchanged between images captured from different cameras. Therefore, they can be used to explicitly guide the model to learn a robust person representation by defining human characteristics.

In this work, we propose to utilize the person attribute information with its associated body part to encode the visual person signature in one unified framework.

  • We hypothesize that the detailed person descriptions (attributes labels) can be integrated with visual features (body parts and global features) to create a unique signature for a particular person.
  • Since both body parts and attributes provide local representations, by linking them together, the network can have a better understanding of the relationship between visual features and attribute descriptions.
  • Although previous works have investigated how person identity, body parts, and attributes benefit the task of person ReID, our key difference is that we utilize Graph Convolutional Networks (GCN) to effectively construct and model the correlation between attributes and body parts with global features. In particular, we treat body part regions and attributes as nodes in a graph and utilize a GCN to learn the topological structure of a person's signature. The GCN propagates messages on a graph structure. After message traversal on the graph, each node's final representation is obtained from its own data and from the other nodes' information. Figure 1 shows the effectiveness of our approach.

Fig-1

Figure 1: The effectiveness of our GPS in improving retrieval results on Market-1501 dataset. The details are presented in the text (Source).

Methodology

Fig-2

Figure 2: Illustration of our proposed framework including two branches: (1) global branch which extracts person global features; (2) GPS branch which performs reasoning the person attributes and body parts using GCN. The details are presented in the text (Source).

The proposed framework is presented in Figure 2.

  • We denote $I$ as a probe person image. This probe image $I$ is first passed through a backbone CNN to get the feature map $\mathbf{F}$.
  • By utilizing a human parsing pretrained model, we extract the body part masks to obtain the visual features of each part.
  • The person attributes are then represented by a lookup word embedding.
  • Given body part features and attribute features, we construct the Graph-based Person Signature, which includes attribute nodes and body part nodes conditioned on the correlation matrix. We employ the GCN for reasoning on the person signature graph and encoding the graph into more representative features.
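The graph reasoning step in the list above can be sketched as a single GCN propagation, where every node (a body part or an attribute) mixes its own features with those of its neighbours. This is a generic illustration, not the authors' implementation: the row normalization with self-loops and the helper names are assumptions, and the adjacency values stand in for the correlation matrix.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def normalize_rows(adjacency):
    """Add self-loops and scale each row to sum to one (an assumed normalization)."""
    size = len(adjacency)
    with_loops = [
        [adjacency[i][j] + (1.0 if i == j else 0.0) for j in range(size)]
        for i in range(size)
    ]
    return [[value / sum(row) for value in row] for row in with_loops]


def gcn_layer(adjacency, features, weight):
    """One propagation step: ReLU(A_hat @ features @ weight)."""
    propagated = matmul(matmul(normalize_rows(adjacency), features), weight)
    return [[max(0.0, value) for value in row] for row in propagated]


# Hypothetical graph: node 0 is a body part, node 1 is an attribute, and they are linked.
ADJACENCY = [[0.0, 1.0], [1.0, 0.0]]
FEATURES = [[1.0, 2.0], [3.0, 4.0]]
IDENTITY = [[1.0, 0.0], [0.0, 1.0]]
UPDATED = gcn_layer(ADJACENCY, FEATURES, IDENTITY)
```

With the identity weight matrix, each linked node ends up with the average of its own and its neighbour's features, which is the message-passing behaviour described in the Motivation section.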

Our proposed method is a multi-branch multi-task framework for person ReID, where the main branch performs the verification task by optimizing two well-known loss functions: Triplet loss and Center loss. The auxiliary branch performs reasoning on the proposed person signature graph and solves the attribute recognition as well as the person identity classification tasks.
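The two losses optimized by the main branch have standard textbook forms, sketched below for single feature vectors. The margin value and function names are illustrative assumptions; the text does not state the hyper-parameters used.

```python
import math


def euclidean(a, b):
    """Euclidean distance between two feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor, positive, negative, margin=0.3):
    """Hinge on the gap between anchor-positive and anchor-negative distances.

    The margin of 0.3 is a placeholder, not a value taken from the text.
    """
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


def center_loss(feature, class_center):
    """Half the squared distance between a feature and its identity's center."""
    return 0.5 * sum((x - c) ** 2 for x, c in zip(feature, class_center))
```

The triplet loss pushes images of different identities apart, while the center loss pulls features of the same identity toward a shared center.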

Experimental Results

Ablation Study

Loss Contribution. In Table 2, we show the contribution of each loss to the final performance on the Market-1501 dataset. The person ID classification loss, triplet loss, center loss, and attribute recognition loss are denoted as $\mathcal{L}_{id}$, $\mathcal{L}_{triplet}$, $\mathcal{L}_{center}$, and $\mathcal{L}_{a}$, respectively. The performance improves when we incorporate all losses into the framework, which justifies the effectiveness of our proposed method. By using only $\mathcal{L}_{id}$, we still achieve results comparable with other mask-guided and attribute-based methods. While the triplet loss $\mathcal{L}_{triplet}$ clearly improves the performance, the center loss $\mathcal{L}_{center}$ has only a slight impact. Notably, the attribute loss $\mathcal{L}_{a}$ remains stable when combined with the other loss functions.

Tab-2

Table 2: The contribution of losses to the performance of person ReID task on Market1501 dataset. Note that the experiments are conducted with ResNet-50 as backbone CNN network (Source).

Model Interpretability. In this section, we conduct cross-dataset experiments to evaluate the effectiveness of GPS. The model is trained on the source dataset and tested directly on the target dataset without fine-tuning. As shown in Table 3, our GPS achieves a significant improvement over the Bag-of-Tricks baseline (BoT). This demonstrates the interpretability of our proposed method and confirms the effectiveness of learning the attributes for the person ReID task.

Tab-3

Table 3: The transferable ability of our GPS evaluated on cross-dataset (Source).

Training Parameters. We also provide the number of training parameters of our GPS and the baseline BoT in Table 4 to show the complexity of each method. Overall, our GPS adds only about 3M parameters compared with the baseline BoT while achieving much better performance.

Tab-4

Table 4: The number of parameters of our GPS in comparison with the baseline BoT on Market-1501 and DukeMTMC-ReID datasets using ResNet-50 as the backbone network. #nParam indicates the number of parameters and 1K=1000 (Source).

Comparison to the state of the art

GPS vs. Baseline. The last two rows of Table 5 show the results of our GPS when integrated into the baseline BoT. The results clearly show that our GPS significantly improves the performance of BoT on both the Market-1501 and DukeMTMC-ReID datasets. This demonstrates the effectiveness of our GPS and confirms the usefulness of learning the attributes in the ReID task.

Tab-5

Table 5: Comparison with state-of-the-art methods on Market-1501 and DukeMTMC-ReID datasets. The cyan and yellow boxes are the best results corresponding to mask-guided/attribute-based and other approaches, respectively. Note that no post-processing is applied to our method (Source).

Evaluation on Market-1501. We compare our GPS with other methods on the Market-1501 dataset in Table 5. The results show that our method outperforms the state-of-the-art attribute-based method AANet, which uses attribute and body part information, in all evaluation metrics. Specifically, we outperform AANet by 5.3% and 1.3% at mAP and R-1, respectively. Our GPS also outperforms the state-of-the-art mask-guided methods; in particular, we outperform P$^2$-Net by 2.2% at mAP. At the same time, we obtain results comparable with other recent ReID approaches.

Evaluation on DukeMTMC-ReID. Table 5 also summarizes the results of our GPS and other methods on the DukeMTMC-ReID dataset. Our GPS significantly outperforms other attribute-based methods in all metrics. Specifically, our method outperforms the recent state-of-the-art attribute-based method AANet by 6.1% at mAP and 1.8% at R-1. In addition, we outperform ADPR by 9.0%, 3.9%, 2.8%, and 2.0% at mAP, R-1, R-5, and R-10, respectively. Moreover, our GPS outperforms the state-of-the-art mask-guided method P$^2$-Net by 4.9%, 1.7%, 2.1%, and 1.7% at mAP, R-1, R-5, and R-10, respectively. We also achieve results comparable with other ReID approaches.

Attribute-based and Mask-guided vs. Other approaches. From Table 5, we notice that although our GPS shows a definite improvement over mask-guided and attribute-based methods, it only achieves competitive results against methods from other approaches and, in particular, is outperformed by the st-ReID method. Note that st-ReID also outperforms all methods from all other approaches. The effectiveness of st-ReID comes from the fact that it also feeds spatial-temporal information (i.e., the spatial map of the camera setting and temporal information from video timestamps) into the network. This extra information allows the network to encode the person identity from multiple viewpoints, which significantly reduces the effect of different poses, viewpoints, or ambiguity challenges. From experiments, we have observed that our GPS, as well as other attribute-based and mask-guided methods, suffers from the fact that the pretrained body part network cannot provide adequate segmentation masks, so the retrieval results are also affected.

Qualitative Visualization

Fig-5

Figure 3: Top 5 retrieval results of some queries on Market-1501 dataset. Note that the green/red boxes denote true/false retrieval results, respectively.

We present some retrieval examples with five retrieved images for each query in Figure 3. As the visualization shows, our GPS obtains better retrieval results than the baseline. In the first row of Figure 3, the baseline returns a false retrieval result at Rank-5 because the retrieved person shares the same gender, wears a hat, and so on, and differs only in the color of the clothes. By leveraging our GPS, the extracted features are more robust with respect to attribute and body part information, which leads to better retrieval results for the ReID model. In the second row, the model with our GPS gives better results by extracting more information about the relationship between the 'backpack' attribute and this person identity, thereby eliminating false cases. We also show an example in which our GPS does not yet produce entirely correct retrieval results in the third row of Figure 3. In this case, the lower body of the probe image is partly covered by the bicycle. Thus, the extracted features (i.e., the color of the pants) are not fully captured, which results in feature misalignment between the probe image and the retrieval results.

Conclusion

This paper proposes Graph-based Person Signature (GPS) that effectively captures the dependencies of person attributes and body parts information. We utilize the GCN on the GPS to propagate the information among nodes in the graph and integrate the graph features into a novel multi-branch multi-task network. The experimental results on benchmark datasets confirm the effectiveness of our GPS and demonstrate that our GPS performs better than recent state-of-the-art attribute-based and mask-guided ReID methods.

Open Source

🐱 Github: https://github.com/aioz-ai/CVPRW21_GPS

Compact Trilinear Interaction for Visual Question Answering

In Visual Question Answering (VQA), answers are strongly correlated with the question meaning and the visual content. Thus, to selectively utilize image, question, and answer information, we propose a novel trilinear interaction model which simultaneously learns high-level associations between these three inputs. In addition, to overcome the interaction complexity, we introduce a multimodal tensor-based PARALIND decomposition which efficiently parameterizes the trilinear interaction between the three inputs. Moreover, knowledge distillation is applied in Free-form Opened-ended VQA, not only to reduce the computational cost and required memory but also to transfer knowledge from the trilinear interaction model to a bilinear interaction model. Extensive experiments on the benchmark datasets TDIUC, VQA-2.0, and Visual7W show that the proposed compact trilinear interaction model achieves state-of-the-art results on all three datasets.

For free-form opened-ended VQA task, CTI achieved 67.4 on VQA-2.0 and 87.0 on TDIUC dataset in VQA accuracy metric.

For multiple choice VQA task, CTI achieved 72.3 on Visual7W dataset in MC-VQA accuracy metric.

Compact Trilinear Interaction in VQA.

Let $M = \{M_1, M_2, M_3\}$ be the representations of three inputs. $M_t \in \mathbb{R}^{n_t \times d_t}$, where $n_t$ is the number of channels of the input $M_t$ and $d_t$ is the dimension of each channel. For example, if $M_1$ is the region-based representation for an image, then $n_1$ is the number of regions and $d_1$ is the dimension of the feature representation for each region. Let $m_{t_e} \in \mathbb{R}^{1 \times d_t}$ be the $e^{th}$ row of $M_t$, i.e., the feature representation of the $e^{th}$ channel in $M_t$, where $t \in \{1, 2, 3\}$.

The input for training VQA is a set of $(V, Q, A)$ in which $V$ is an image representation, $V \in \mathbb{R}^{v \times d_v}$, where $v$ is the number of interested regions (or bounding boxes) in the image and $d_v$ is the dimension of the representation for a region; $Q$ is a question representation, $Q \in \mathbb{R}^{q \times d_q}$, where $q$ is the number of hidden states and $d_q$ is the dimension for each hidden state; and $A$ is an answer representation, $A \in \mathbb{R}^{a \times d_a}$, where $a$ is the number of hidden states and $d_a$ is the dimension for each hidden state.

We firstly compute the attention map M\mathcal{M} as follows:

$$\mathcal{M} = \sum^R_{r=1} \llbracket \mathcal{G}_r; V W_{v_r}, Q W_{q_r}, A W_{a_r} \rrbracket$$

Then the joint representation zz is computed as follows:

$$z^T = \sum_{i=1}^{v}\sum_{j=1}^{q}\sum_{k=1}^{a} \mathcal{M}_{ijk}\left( V_{i}W_{z_v} \circ Q_{j}W_{z_q} \circ A_{k}W_{z_a}\right)$$

where $W_{v_r}, W_{q_r}, W_{a_r}$ and $W_{z_v}, W_{z_q}, W_{z_a}$ are learnable factor matrices; each $\mathcal{G}_r$ is a learnable Tucker tensor.
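The triple sum that produces the joint representation can be written out directly. The sketch below assumes that the attention map has already been computed and that the three inputs have already been multiplied by their factor matrices, so each row is a projected feature of the same dimension; the function and argument names are hypothetical.

```python
def joint_representation(attention, v_proj, q_proj, a_proj):
    """Sum attention[i][j][k] * (v_i * q_j * a_k) over all index triples.

    The products are element-wise, and every projected row is assumed to
    have the same length as the joint representation z.
    """
    dim = len(v_proj[0])
    z = [0.0] * dim
    for i, v_row in enumerate(v_proj):
        for j, q_row in enumerate(q_proj):
            for k, a_row in enumerate(a_proj):
                weight = attention[i][j][k]
                for d in range(dim):
                    z[d] += weight * v_row[d] * q_row[d] * a_row[d]
    return z


# Tiny hypothetical example: one region, one question state, one answer state.
Z = joint_representation([[[2.0]]], [[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]])
```

The nested loops visit every (region, question state, answer state) triple, which is why the text turns to a compact decomposition to keep the interaction tractable.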

Integrating CTI into different VQA tasks

For multiple choice VQA

Fig1

Figure 1. The model when CTI is applied to MC VQA.

Each input question and each answer is trimmed to a maximum of 12 words and zero-padded if shorter than 12 words. Each word is then represented by a 300-D GloVe word embedding. Each image is represented by a $14 \times 14 \times 2048$ grid feature (i.e., $196$ cells, each with a $2048$-D feature), extracted from the second-to-last layer of ResNet-152, which is pre-trained on ImageNet.
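The trimming and zero-padding step can be sketched as follows. Whitespace tokenization, right-side padding, and mapping unknown words to index 0 are assumptions made for this illustration; the text only specifies the 12-word limit and zero-padding.

```python
MAX_WORDS = 12
PAD_INDEX = 0  # assumption: index 0 is reserved for padding and unknown words


def encode(sentence, vocabulary, max_words=MAX_WORDS):
    """Map words to indices, trim to max_words, then zero-pad on the right."""
    tokens = sentence.lower().split()[:max_words]
    indices = [vocabulary.get(token, PAD_INDEX) for token in tokens]
    return indices + [PAD_INDEX] * (max_words - len(indices))


# Hypothetical vocabulary for a single question.
VOCABULARY = {"what": 1, "color": 2, "is": 3, "the": 4, "cat": 5}
ENCODED = encode("What color is the cat", VOCABULARY)
```

Each index would then be looked up in the 300-D GloVe embedding table.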

Input samples are divided into positive samples and negative samples. A positive sample, which is labelled as $1$ in binary classification, contains the image, the question, and the right answer. A negative sample, which is labelled as $0$ in binary classification, contains the image, the question, and a wrong answer. These samples are then passed through CTI to get the joint representation $z$. The joint representation is passed through a binary classifier to get the prediction. The Binary Cross Entropy loss is used for training the model.
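A small sketch of how the positive and negative samples and their loss could be formed is given below. The helper names and the string placeholders for the image and question are hypothetical; the classifier that turns the joint representation into a probability is not modelled here.

```python
import math


def binary_cross_entropy(probability, label, eps=1e-12):
    """BCE for one sample; probability is the classifier's score for label 1."""
    p = min(max(probability, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(p) + (1 - label) * math.log(1.0 - p))


def build_samples(image, question, answers, correct_index):
    """Pair the image and question with every candidate answer, labelled 1 or 0."""
    return [
        ((image, question, answer), 1 if index == correct_index else 0)
        for index, answer in enumerate(answers)
    ]


# Hypothetical multiple-choice item with three candidates; the second is correct.
SAMPLES = build_samples("image", "what color is the cat", ["red", "black", "blue"], 1)
```

Each multiple-choice question therefore yields one positive sample and several negative samples.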

For free-form opened-ended VQA

Fig2

Figure 2. The model when CTI is applied to FFOE VQA.

Unlike MC VQA, FFOE VQA treats answering as a classification problem over a set of predefined answers. Hence, the set of possible answers for each question-image pair is much larger than in MC VQA. For each question-image input, the model would have to take every possible answer from the answer list to compute the joint representation, which causes a high computational cost.

In addition, CTI requires all three inputs $V, Q, A$ to compute the joint representation. However, during testing, no answer information is available in FFOE VQA. To overcome these challenges, we propose to use Knowledge Distillation to transfer the learned knowledge from a teacher model to a student model.

The loss function for the student model is defined as:

$$\mathcal{L}_{KD} = \alpha T^2 \mathcal{L}_{CE}(Q^\tau_S, Q^\tau_T) + (1-\alpha)\mathcal{L}_{CE}(Q_S, y_{true})$$

where $\mathcal{L}_{CE}$ stands for the Cross Entropy loss; $Q_S$ is the standard softmax output of the student; $y_{true}$ is the ground-truth answer labels; $\alpha$ is a hyper-parameter for controlling the importance of each loss component; $Q^\tau_S, Q^\tau_T$ are the softened outputs of the student and the teacher using the same temperature parameter $T$, which are computed as follows:

$$Q^\tau_i = \frac{\exp(l_i/T)}{\sum_{j} \exp(l_j/T)}$$

where, for both the teacher and the student models, the logits $l$ are the predictions output by the corresponding classifiers.
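The distillation loss can be sketched in a few lines. This version assumes a one-hot ground-truth label and uses the teacher's softened output as the target of the first cross-entropy term; the values of alpha and the temperature are placeholders, since the text does not state them.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; larger temperatures give softer distributions."""
    scaled = [logit / temperature for logit in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def cross_entropy(predicted, target, eps=1e-12):
    """Cross entropy of a predicted distribution against a target distribution."""
    return -sum(t * math.log(max(p, eps)) for p, t in zip(predicted, target))


def distillation_loss(student_logits, teacher_logits, true_index,
                      alpha=0.5, temperature=3.0):
    """alpha * T^2 * CE(soft student, soft teacher) + (1 - alpha) * CE(student, label)."""
    soft_student = softmax(student_logits, temperature)
    soft_teacher = softmax(teacher_logits, temperature)
    hard_student = softmax(student_logits)
    # Assumption: the ground truth is a single correct answer index (one-hot).
    one_hot = [1.0 if i == true_index else 0.0 for i in range(len(student_logits))]
    soft_term = cross_entropy(soft_student, soft_teacher)
    hard_term = cross_entropy(hard_student, one_hot)
    return alpha * temperature ** 2 * soft_term + (1.0 - alpha) * hard_term
```

Setting alpha to 0 reduces the loss to the ordinary cross entropy against the ground truth, while larger values give more weight to matching the teacher.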

Results

Tab-1

Table 1. Performance of CTI and BAN2, SAN on the VQA-2.0 validation set and test-dev set. BAN2-CTI and SAN-CTI are student models trained under the teacher model.

To further evaluate the effectiveness of CTI, we conduct a detailed comparison with the current state of the art. For FFOE VQA, we compare CTI with the recent state-of-the-art methods on TDIUC and VQA-2.0 datasets. For MC VQA, we compare with the state-of-the-art methods on Visual7W dataset.

Tab-2

Table 2. Performance comparison between different approaches with different evaluation metrics on TDIUC validation set. BAN2-CTI and SAN-CTI are the student models trained under compact trilinear interaction teacher model.

Regarding FFOE VQA, Table 1 and Table 2 show comparative results on VQA-2.0 and TDIUC, respectively. Specifically, Table 2 shows that the distilled student BAN2-CTI outperforms all compared methods over all metrics by a large margin, i.e., the model outperforms the current state-of-the-art QTA on TDIUC by 3.4% and 5.4% on the Ari and Har metrics, respectively. The results confirm that trilinear interaction has learned informative representations from the three inputs and that the learned information is effectively transferred to student models by distillation.

Tab-3

Table 3. Performance comparison between different approaches on the Visual7W test set. Both the training set and the validation set are used for training. All models except CTIwBoxes are trained with the same image and question representations. Note that CTIwBoxes is the CTI model using Bottom-up features instead of grid features for image representation.

Regarding MC VQA, Table 3 shows that CTI outperforms the compared methods by a noticeable margin. This model outperforms the current state-of-the-art STL by 1.1%. Again, this validates the effectiveness of the proposed joint representation learning, which precisely and simultaneously learns interactions between the three inputs. We note that when comparing with other methods on Visual7W, for image representations, we used the grid features extracted from ResNet-152 for a fair comparison. Our proposed model can achieve further improvements by using the object detection-based features used in FFOE VQA. With these features, the model denoted as CTIwBoxes in Table 3 achieves 72.3% accuracy on the Acc-MC metric, which improves over the current state-of-the-art STL by 4.1%.

Conclusion

A novel compact trilinear interaction is introduced to simultaneously learn high-level associations between image, question, and answer in both MC VQA and FFOE VQA. In addition, knowledge distillation is applied to FFOE VQA for the first time to overcome the computational complexity and memory issues of the interaction. The extensive experimental results show that these models achieve state-of-the-art results on three benchmark datasets.

Open Source

Github: https://github.com/aioz-ai/ICCV19_VQA-CTI