
Music-Driven Group Choreography (Part 3)

This is the final part of our series on group dance choreography. In this part, we provide detailed analyses of our proposed group dance generation method.


AIOZ-GDANCE Statistics

Figure 1. Distribution (%) of music genres (a) and dance styles (b) in our dataset.

In Figure 1, we show the distribution of music genres and dance styles in our dataset. As illustrated in Figure 1 (a), Pop and Electronic are the most popular music genres, while the other genres have nearly equal shares. Meanwhile, Figure 1 (b) shows that Zumba, Aerobic, and Commercial are the dominant dance styles.

Figure 2. The correlation between dance styles and number of dancers (a); and between dance styles and music genres (b).

Figure 2 (a) shows the number of dancers in each dance style. Naturally, we see that Zumba, Aerobic, and Commercial have more dancers. Figure 2 (b) illustrates the correlation between music genres and dance styles.

Evaluation Metrics

Similar to prior works on single dance generation, we evaluate the generated motion quality by calculating the distribution distance between the generated and the ground-truth motions using the Fréchet Inception Distance (FID) [1, 2]. To evaluate how well the generated 3D motion correlates with the input music, we use the Motion-Music Consistency metric (MMC) [2, 3]. We also evaluate our model's ability to generate diverse dance motions when given various input music by measuring Generation Diversity (GenDiv) [2, 3].

To evaluate the group dancing quality, we propose three new metrics: Group Motion Realism (GMR), Group Motion Correlation (GMC), and Trajectory Intersection Frequency (TIF). Detailed calculations of these metrics are described as follows:

Group Motion Realism (GMR). To calculate the realism between generated and ground-truth group motion, we need a single unified representation for all dancers' motions in the scene. Based on the kinetic features of a single motion sequence [4], we propose Group Motion Realism (GMR), for which smaller is better. Specifically, for each entity $n$, we compute the velocity of each element $j$ of the pose vector: $v^n_t = \frac{y^n_{t+1} - y^n_t}{\Delta t}$, where $\Delta t$ is the time period between two consecutive frames. Note that the pose vector of each entity at each frame consists of the root orientation, root position, and joint angles. The group kinetic features of a sequence are approximated by taking the logarithm of the total kinetic energy of all group entities:

$$e_j = \log \left(1 + \frac{1}{T}\frac{1}{N} \sum_{t=1}^T \sum_{n=1}^N m_j (v^n_{t,j})^2\right)$$

where $m_j$ is the moment of inertia or mass of each joint. We assume that $m_j$ is constant with respect to time and entity. Then, we split the sequence into smaller chunks and calculate the features of these chunks. This process is identical for both the generated and ground-truth sequences. Finally, we use these sets of features (from generated and ground-truth group dance) to calculate the GMR using the standard FID formulation as in [1].
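The GMR computation described above can be sketched in NumPy as follows. This is only an illustrative sketch, not the released evaluation code: the function names, the unit mass default, and the 30 FPS time step are assumptions, and the Fréchet distance is computed through the eigenvalues of the covariance product rather than a matrix square root routine.

```python
import numpy as np

def group_kinetic_features(poses, dt=1.0 / 30.0, mass=None):
    """Group kinetic feature e_j for one chunk.

    poses: array of shape (T, N, D) with T frames, N dancers, D pose elements.
    Returns an array of shape (D,).
    """
    poses = np.asarray(poses, dtype=np.float64)
    # Finite-difference velocity; only T-1 velocities exist for T frames.
    velocity = (poses[1:] - poses[:-1]) / dt
    if mass is None:
        mass = np.ones(poses.shape[-1])  # assumption: unit mass for every element
    energy = mass * velocity ** 2
    # Average over time and dancers, then log(1 + x).
    return np.log1p(energy.mean(axis=(0, 1)))

def frechet_distance(feats_a, feats_b):
    """Frechet distance between two feature sets of shape (num_chunks, D), D >= 2."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # tr(sqrt(cov_a cov_b)) equals the sum of square roots of the eigenvalues
    # of the product, which are real and non-negative up to numerical noise.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(np.sum((mu_a - mu_b) ** 2)
                 + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)
```

In this sketch, GMR would be `frechet_distance` applied to the stacked chunk features of the generated and the ground-truth sequences.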

Group Motion Correlation (GMC). We also evaluate the synchrony and the correlation between dancers within the generated group. We assume that the correlation of movements between individuals is likely to reflect their interaction in the choreography. For every pair of motions within a group, we first align the two motion sequences using the Dynamic Time Warping algorithm based on the Euclidean distance in the joint position space (obtained by the SMPL joint regressor). We then calculate the mean cross-correlations between the time-aligned motion pairs using the kinetic features [4]. The generated group motion correlation degree is then calculated as the average over all motion pairs.
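A rough sketch of this pairwise procedure is given below. It assumes a plain dynamic-programming DTW and a zero-lag normalized correlation averaged over feature dimensions; the exact correlation used in our evaluation may differ, and all function names here are hypothetical.

```python
from itertools import combinations
import numpy as np

def dtw_path(a, b):
    """Align sequences a (Ta, D) and b (Tb, D); returns a list of (i, j) index pairs."""
    ta, tb = len(a), len(b)
    dist = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (Ta, Tb)
    cost = np.full((ta + 1, tb + 1), np.inf)
    cost[0, 0] = 0.0
    for i in range(1, ta + 1):
        for j in range(1, tb + 1):
            cost[i, j] = dist[i - 1, j - 1] + min(
                cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1])
    # Backtrack from the end of both sequences to the start.
    i, j = ta, tb
    path = [(i - 1, j - 1)]
    while (i, j) != (1, 1):
        steps = [(i - 1, j - 1), (i - 1, j), (i, j - 1)]
        i, j = min(steps, key=lambda s: cost[s])
        path.append((i - 1, j - 1))
    return path[::-1]

def aligned_correlation(feat_a, feat_b, path):
    """Mean normalized correlation of two feature sequences along a DTW path."""
    idx_a, idx_b = zip(*path)
    xa, xb = feat_a[list(idx_a)], feat_b[list(idx_b)]
    xa = xa - xa.mean(axis=0)
    xb = xb - xb.mean(axis=0)
    denom = np.sqrt((xa ** 2).sum(axis=0) * (xb ** 2).sum(axis=0)) + 1e-8
    return float(((xa * xb).sum(axis=0) / denom).mean())

def group_motion_correlation(joints, feats):
    """joints: (N, T, J*3) joint positions; feats: (N, T, D) per-frame features."""
    scores = []
    for i, j in combinations(range(len(joints)), 2):
        path = dtw_path(joints[i], joints[j])
        scores.append(aligned_correlation(feats[i], feats[j], path))
    return float(np.mean(scores))
```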

Trajectory Intersection Frequency (TIF). For the generated group sequences, the intersection rate is calculated over all $F$ frames as:

$$\text{TIF} = \frac{\sum_{F}\sum_{i,j : i\neq j} \mathbb{I}[\text{intersect}(M(y^i),M(y^j))]}{F},$$

where $M$ is the SMPL skinning function [5], which outputs a 6890-vertex human mesh from the input pose parameters $y$, and $\text{intersect}(x,y)$ is a function that returns 1 if the two meshes intersect each other and 0 otherwise. For TIF, a smaller value is better and indicates less intersection in the generated group.
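The counting logic of this metric can be illustrated as below. A true mesh-to-mesh intersection test is beyond a short example, so the sketch swaps in a bounding-sphere overlap check as a stand-in; that predicate and the function names are assumptions, and any exact mesh test can be passed in its place.

```python
import numpy as np

def spheres_overlap(verts_a, verts_b):
    """Stand-in intersection test (assumption): do the bounding spheres overlap?"""
    center_a, center_b = verts_a.mean(axis=0), verts_b.mean(axis=0)
    radius_a = np.linalg.norm(verts_a - center_a, axis=1).max()
    radius_b = np.linalg.norm(verts_b - center_b, axis=1).max()
    return np.linalg.norm(center_a - center_b) < radius_a + radius_b

def trajectory_intersection_frequency(meshes, intersect=spheres_overlap):
    """meshes: (F, N, V, 3) vertices for F frames, N dancers, V vertices each.

    Counts ordered pairs (i, j), i != j, as in the formula, then divides by F.
    """
    num_frames, num_dancers = meshes.shape[:2]
    count = 0
    for f in range(num_frames):
        for i in range(num_dancers):
            for j in range(num_dancers):
                if i != j and intersect(meshes[f, i], meshes[f, j]):
                    count += 1
    return count / num_frames
```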

Cross-entity Attention Analysis

We compare our method with FACT [2]. FACT is a recent state-of-the-art method designed for single dance generation, which gives our method an advantage. However, it is still the closest competing method, since our group dance dataset is new and no prior method has been benchmarked on it. We also analyse our method with and without Cross-entity Attention. We train all methods with mini-batches containing all dancers within a group instead of sampling each dancer independently as in FACT's original implementation.

Figure 3. Comparison between FACT and our GDanceR. Our method better handles consistency and the cross-body intersection problem between dancers.

Table 1 compares the baseline FACT [2] with our proposed GDanceR, with and without Cross-entity Attention. The results show that GDanceR, especially with Cross-entity Attention, outperforms the baseline by a large margin in all metrics. In Figure 3, we also visualize example outputs of FACT and GDanceR. It is clear that FACT does not handle the intersection problem well. This is understandable, as FACT is not designed for group dance generation, while our method with Cross-entity Attention deals with this problem better.

Table 1. Generation results comparison on AIOZ-GDANCE dataset. w/o CA denotes without using Cross-entity Attention.

Number of Dancers Analysis

Table 2 reports the generation results of our method for different numbers of generated dancers. In general, the FID, GMR, and GMC metrics do not show a clear correlation with the number of generated dancers, since the results vary. On the other hand, MMC is stable across all setups ($\sim 0.248$), which indicates that our network is robust in generating motion from the given music regardless of changes in the initial positions. The generation diversity (GenDiv) decreases while the intersection frequency (TIF) increases when more dancers are generated. These results show that dealing with collisions during the group generation process is worth further investigation.

Table 2. Performance of our proposed method when increasing the number of generated dancers.

Dance Style Analysis

Figure 4. Examples of generated group motions from our method.

Different dance styles pose different challenges in group dance generation. As shown in Table 3, Aerobic and Zumba behave similarly when generating choreography, as they usually focus on workout and sporty movements. Besides, while the motions of Commercial and Irish are easier for the model to reproduce, Bollywood and Samba contain highly skilled movements that are challenging to capture and represent accurately. In Figure 4, we show the generated results of GDanceR with different dance styles. Our Supplementary Material and Demonstration Video also provide more examples.

Table 3. The results of different dance styles. These results are obtained by training the model on each dance style.

Ablation on Latent Motion Fusion

We investigate different fusion strategies between the local motion $h^i$ and the global-aware motion $g^i$ to obtain the final motion representation $z^i$. Specifically, we experiment with three settings: (i) No Fusion: the final motion is the global-aware motion obtained from our Cross-entity Attention ($z^i = g^i$); (ii) Concatenate: the final motion is the concatenation of the local and global-aware motion ($z^i = [h^i; g^i]$); (iii) Add: the final motion is the sum of the local and global-aware motion ($z^i = h^i + g^i$). Table 4 summarizes the results. We find that fusing the motion by adding the local and global motion features achieves the best results. In this strategy, the global information between entities is encoded into the local motion in an effective way, so that the final motion retains comprehensive information about each entity's own past motion as well as the motion of every other entity. With concatenation, in contrast, the model is prone to overfitting due to the redundant information in the local and global representations. On the other hand, No Fusion can reduce the amount of information about the past motion, leading to insufficient input information, so the Decoder may fail to generate temporally coherent motion aligned with the music.
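The three settings can be written down in a few lines. This is a minimal sketch with a hypothetical `fuse` helper; note that the Add setting requires $h^i$ and $g^i$ to share the same dimension, while Concatenate doubles the input size of the decoder.

```python
import numpy as np

def fuse(h, g, mode="add"):
    """Combine local motion h and global-aware motion g, both of shape (N, d)."""
    if mode == "none":       # (i) No Fusion: z = g
        return g
    if mode == "concat":     # (ii) Concatenate: z = [h; g]
        return np.concatenate([h, g], axis=-1)
    if mode == "add":        # (iii) Add: z = h + g
        return h + g
    raise ValueError(f"unknown fusion mode: {mode}")
```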

Table 4. Ablation study on different fusion strategies for the latent motion representation.


In summary, we have introduced AIOZ-GDANCE, the largest dataset for audio-driven group dance generation. Our dataset contains in-the-wild videos and covers different dance styles and music genres. We then propose a strong baseline along with new evaluation metrics for the group dance generation task. We also perform extensive experiments to validate our method on this interesting yet unexplored problem, using our new dataset and evaluation protocols. We hope that the release of our dataset will foster more research on audio-driven group choreography.


[1] Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. NIPS 2017.

[2] Ruilong Li, Shan Yang, David A Ross, and Angjoo Kanazawa. AI Choreographer: Music-conditioned 3D dance generation with AIST++. ICCV 2021.

[3] Hsin-Ying Lee, Xiaodong Yang, Ming-Yu Liu, Ting-Chun Wang, Yu-Ding Lu, Ming-Hsuan Yang, and Jan Kautz. Dancing to music. NIPS 2019.

[4] Kensuke Onuma, Christos Faloutsos, and Jessica K Hodgins. FMDistance: A fast and effective distance function for motion capture data. Eurographics 2008.

[5] Matthew Loper, Naureen Mahmood, Javier Romero, Gerard Pons-Moll, and Michael J. Black. SMPL: A skinned multi-person linear model. ACM Trans. Graphics, 2015.

Music-Driven Group Choreography (Part 2)

In the previous post, we introduced AIOZ-GDANCE, a new large-scale in-the-wild dataset for music-driven group dance generation. On the basis of the new dataset, we introduce the first strong baseline for group dance generation that can jointly generate multiple dancing motions expressively and coherently.

Figure 1. The overall architecture of GDanceR. Our model takes in a music sequence and a set of initial positions, and then auto-regressively generates coherent group dance motions that are attuned to the input music.

Music-driven Group Dance Generation Method

Problem Formulation

Given an input music audio sequence $\{m_1, m_2, ..., m_T\}$, where $t \in \{1,..., T\}$ indicates the index of the music segments, and the initial 3D positions of $N$ dancers $\{\tau^1_0, \tau^2_0, ..., \tau^N_0\}$, $\tau^i_0 \in \mathbb{R}^{3}$, our goal is to generate the group motion sequences $\{y^1_1,..., y^1_T; ...; y^N_1,...,y^N_T\}$, where $y^i_t$ is the generated pose of the $i$-th dancer at time step $t$. Specifically, we represent the human pose as a 72-dimensional vector $y = [\tau; \theta]$, where $\tau$ and $\theta$ represent the root translation and pose parameters of the SMPL model [1], respectively.

In general, the generated group dance motion should meet two conditions: (i) consistency between the generated dancing motion and the input music in terms of style, rhythm, and beat; (ii) the motions and trajectories of dancers should be coherent without cross-body intersection between dancers. To that end, we propose the first baseline method for group dance generation that can jointly generate multiple dancing motions expressively and coherently. Figure 1 shows the architecture of our proposed Music-driven 3D Group Dance generatoR (GDanceR), which consists of three main components:

  • Transformer Music Encoder.
  • Initial Pose Generator.
  • Group Motion Generator.

Transformer Music Encoder

From the raw audio signal of the input music, we first extract music features using the publicly available audio processing library Librosa. Concretely, we extract the mel frequency cepstral coefficients (MFCC), MFCC delta, constant-Q chromagram, tempogram, onset strength, and one-hot beat, which results in a 438-dimensional feature vector. We then encode the music sequence $M = \{m_1, m_2, ..., m_T\}$, $m_t \in \mathbb{R}^{438}$, into a sequence of hidden representations $\{a_1, a_2,..., a_T\}$, $a_t \in \mathbb{R}^{d_a}$. In practice, we utilize the self-attention mechanism of the Transformer [2] to effectively encode the multi-scale information and the long-term dependency between music frames. The hidden audio at each time step is expected to contain meaningful structural information to ensure that the generated dancing motion is coherent across the whole sequence.

Specifically, we first embed the music features $m_t$ using a linear layer followed by Positional Encoding to encode the time ordering of the sequence:

$$U = \text{PE}(M W^u_a)$$

where $\text{PE}$ denotes the Positional Encoding, and $W^u_a \in \mathbb{R}^{438 \times d_a}$ are the parameters of the linear projection layer. Then, the hidden audio information can be calculated using the self-attention mechanism:

$$\mathbb{A} = \text{FF}\left(\text{softmax}\left(\frac{U^q (U^k)^\top}{\sqrt{d_{k}}} \right) U^v \right), \qquad U^q = U W^q_a, \quad U^k = U W^k_a, \quad U^v = U W^v_a$$

where $W^q_a, W^k_a \in \mathbb{R}^{d_a \times d_k}$ and $W^v_a \in \mathbb{R}^{d_a \times d_v}$ are the parameters that transform the linearly embedded audio $U$ into a query $U^q$, a key $U^k$, and a value $U^v$, respectively. $d_a$ is the dimension of the hidden audio representation, $d_k$ is the dimension of the query and key, and $d_v$ is the dimension of the value. $\text{FF}$ is a feed-forward neural network.
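To make the shapes concrete, the encoder equations can be sketched in NumPy as a single-head attention layer. Several details here are assumptions rather than part of the description above: the sinusoidal form of the Positional Encoding (added to the embedding), the two-layer ReLU feed-forward network, and the parameter names in the `params` dictionary.

```python
import numpy as np

def positional_encoding(num_steps, dim):
    """Sinusoidal positional encoding of shape (num_steps, dim) (assumed form)."""
    pos = np.arange(num_steps)[:, None]
    idx = np.arange(dim)[None, :]
    angle = pos / np.power(10000.0, (2 * (idx // 2)) / dim)
    return np.where(idx % 2 == 0, np.sin(angle), np.cos(angle))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def encode_music(M, params):
    """M: (T, 438) music features. Returns hidden audio of shape (T, d_a)."""
    U = M @ params["W_u"]                                   # linear embedding, (T, d_a)
    U = U + positional_encoding(U.shape[0], U.shape[1])     # PE(M W_u)
    Uq, Uk, Uv = U @ params["W_q"], U @ params["W_k"], U @ params["W_v"]
    attn = softmax(Uq @ Uk.T / np.sqrt(Uq.shape[-1]))       # (T, T) attention weights
    context = attn @ Uv                                     # (T, d_v)
    hidden = np.maximum(context @ params["W_1"], 0.0)       # feed-forward with ReLU
    return hidden @ params["W_2"]                           # back to (T, d_a)
```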

Initial Pose Generator

Figure 2. The Transformer Music Encoder encodes the acoustic and rhythmic information to generate the initial poses from the input positions.

Given the initial positions of all dancers, we generate the initial poses by combining the audio feature with the starting positions. We aggregate the audio representation by taking an average over the audio sequence. The aggregated audio is then concatenated with the input position and fed to a multilayer perceptron (MLP) to predict the initial pose for each dancer:

$$y^i_0 = \text{MLP}\left( \left[\frac{1}{T}\sum_{t=1}^T a_t ; \tau^i_0 \right] \right),$$

where $[;]$ is the concatenation operator and $\tau^i_0$ is the initial position of the $i$-th dancer.
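As a small illustration of this step, the sketch below averages the hidden audio over time, concatenates it with each dancer's starting position, and applies a two-layer MLP. The hidden width, the ReLU activation, and the function names are assumptions made for the example.

```python
import numpy as np

def mlp(x, W1, b1, W2, b2):
    """Two-layer perceptron with ReLU (assumed architecture)."""
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2

def initial_poses(audio_hidden, init_positions, weights):
    """audio_hidden: (T, d_a); init_positions: (N, 3). Returns poses of shape (N, pose_dim)."""
    a_mean = audio_hidden.mean(axis=0)                        # average over the sequence
    n = init_positions.shape[0]
    a_tiled = np.broadcast_to(a_mean, (n, a_mean.shape[0]))   # same audio for every dancer
    x = np.concatenate([a_tiled, init_positions], axis=1)     # [mean(a); tau_0^i]
    return mlp(x, *weights)
```

Because the audio summary is shared, two dancers only receive different initial poses through their different starting positions.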

Group Motion Generator

Figure 3. The Group Motion Generator auto-regressively generates coherent group dance motions based on the encoded acoustic information.

To generate the group dance motion, we aim to synthesize the coherent motion of each dancer such that it aligns well with the input music. Furthermore, we also need to maintain global consistency between all dancers. As shown in Figure 3, our Group Motion Generator comprises a Group Encoder to encode the group sequence information and an MLP Decoder to decode the hidden representation back to the human pose space. To effectively extract both the local motion and the global information of the group dance sequence through time, we design our Group Encoder based on two components: a Recurrent Neural Network [3] to capture the temporal motion dynamics of each dancer, and an Attention mechanism [2] to encode the spatial relationship of all dancers.

Specifically, at each time step, the pose of each dancer in the previous frame $y^i_{t-1}$ is sent to an LSTM unit to encode the hidden local motion representation $h^i_t$:

$$h^i_t = \text{LSTM}(y^i_{t-1}, h^i_{t-1})$$

To ensure the motions of all dancers have global coherency and to discourage strange effects such as cross-body intersection, we introduce the Cross-entity Attention mechanism. In particular, each individual motion representation is first linearly projected into a key vector $k^i$, a query vector $q^i$, and a value vector $v^i$ as follows:

$$k^i = h^i W^{k}, \quad q^i = h^i W^{q}, \quad v^i = h^i W^{v},$$

where $W^q, W^k \in \mathbb{R}^{d_h \times d_k}$ and $W^v \in \mathbb{R}^{d_h \times d_v}$ are parameters that transform the hidden motion $h$ into a query, a key, and a value, respectively. $d_k$ is the dimension of the query and key, while $d_v$ is the dimension of the value vector. To encode the relationship between dancers in the scene, our Cross-entity Attention also utilizes the Scaled Dot-Product Attention as in the Transformer [2].

Figure 4. The Group Encoder learns to encode the relations among dancers through our proposed Cross-entity Attention mechanism.

In practice, we find that people who are closer to each other tend to have higher correlation in their movement. Therefore, we adopt a Spatial Encoding strategy to encode the spatial relationship between each pair of dancers. The Spatial Encoding between two entities, based on their distance in 3D space, is defined as follows:

$$e_{ij} = \exp\left(-\frac{\Vert \tau^i - \tau^j \Vert^2}{\sqrt{d_{\tau}}}\right),$$

where $d_{\tau}$ is the dimension of the position vector $\tau$. Considering the query $q^i$, which represents the current entity information, and the key $k^j$, which represents other entity information, we inject the spatial relation information between these two entities into their cross-attention coefficient:

$$\alpha_{ij} = \text{softmax}\left(\frac{(q^i)^\top k^j}{\sqrt{d_k}} + e_{ij}\right).$$

To preserve the relative spatial information in the attentive representation, we also embed it into the hidden value vector and obtain the global-aware representation $g^i$ of the $i$-th entity as follows:

$$g^i = \sum_{j=1}^N\alpha_{ij}(v^j + e_{ij}\gamma),$$

where $\gamma \in \mathbb{R}^{d_v}$ is a learnable bias that is scaled by the Spatial Encoding. Intuitively, the Spatial Encoding acts as a bias in the attention weight, encouraging interactivity and awareness to be higher between closer entities. Our attention mechanism can adaptively attend to each dancer and the others in both a temporal and a spatial manner, thanks to the encoded motion as well as the spatial information.
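The Cross-entity Attention equations translate fairly directly into array operations. The following NumPy sketch assumes a single attention head and a softmax normalized over the other entities $j$; the function and argument names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_entity_attention(h, tau, W_q, W_k, W_v, gamma):
    """h: (N, d_h) local motion; tau: (N, d_tau) positions; gamma: (d_v,).

    Returns the global-aware motion g of shape (N, d_v) and attention alpha (N, N).
    """
    q, k, v = h @ W_q, h @ W_k, h @ W_v
    # Spatial encoding e_ij = exp(-||tau_i - tau_j||^2 / sqrt(d_tau)).
    diff = tau[:, None, :] - tau[None, :, :]
    e = np.exp(-(diff ** 2).sum(axis=-1) / np.sqrt(tau.shape[-1]))
    # Scaled dot-product logits with the spatial encoding as an additive bias.
    alpha = softmax(q @ k.T / np.sqrt(q.shape[-1]) + e, axis=-1)
    # g_i = sum_j alpha_ij (v_j + e_ij * gamma)
    g = alpha @ v + (alpha * e).sum(axis=-1, keepdims=True) * gamma
    return g, alpha
```

Since $e_{ij}$ decays with distance, a nearby dancer contributes both a larger attention logit and a larger share of the bias $\gamma$ than a distant one.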

We then fuse the local and global motion representations by adding $h^i$ and $g^i$ to obtain the final latent motion $z^i$. Our final global-local representation of each entity is expected to carry comprehensive information about its own past motion as well as the motion of every other entity, enabling the MLP Decoder to generate coherent group dancing sequences. Finally, we generate the next movement $y^i_t$ based on the final motion representation $z^i_t$ as well as the hidden audio representation $a_t$, and can thus capture the fine-grained correspondence between the music feature sequence and the dance movement sequence:

$$y^i_t = \text{MLP}([z^i_t; a_t]).$$

Built upon these components, our model can effectively learn and generate coherent group dance animation given several pieces of music. In the next part, we will go through the experiments and detailed studies of the method.


[1] Matthew Loper, Naureen Mahmood, Javier Romero, Gerard Pons-Moll, and Michael J. Black. SMPL: A skinned multi-person linear model. ACM Trans. Graphics, 2015.

[2] Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I. Attention is all you need. NIPS 2017.

[3] Hochreiter S, Schmidhuber J. Long short-term memory. Neural computation. 1997 Nov 15;9(8):1735-80.

Music-Driven Group Choreography (Part 1)

Dancing is an important part of human culture and remains one of the most expressive physical art and communication forms. With the rapid development of digital social media platforms, creating dancing videos has gained significant attention from social communities. As a result, millions of dancing videos are created and watched daily on online platforms. Recently, studies of how to create natural dancing motion from music have attracted great attention in the research community.

Figure 1. We demonstrate the AIOZ-GDANCE dataset with in-the-wild videos, music audio, and 3D group dance motion.

Nevertheless, generating dance motion for a group of dancers remains an open problem and has not yet been well investigated by the community. Motivated by this gap and to foster research on group choreography, we establish AIOZ-GDANCE, a new large-scale in-the-wild dataset for music-driven group dance generation. Unlike existing datasets that only support single dance, our new dataset contains group dance videos as shown in Figure 1, hence supporting the study of group choreography. On the basis of the new dataset, we propose the first strong baseline for group dance generation that can jointly generate multiple dancing motions expressively and coherently.

Dataset Construction

Figure 2. The pipeline of making our AIOZ-GDANCE dataset.

In this section, we describe the process of building our dataset from a large variety of videos available on the internet. Because our main goal is to develop a large-scale dataset with in-the-wild videos, setting up a MoCap system as in many classical approaches is not feasible. However, manually creating 3D ground truth for millions of frames from dancing videos is also an extremely costly and tedious job. To that end, we propose a semi-automatic labeling method with humans in the loop to obtain the 3D ground truth for our dataset. The process of constructing the data includes the following five key steps:

  1. Video Collection
  2. Human Tracking
  3. Human Pose Estimation
  4. Local Fitting for Individual Motions
  5. Global Scene Optimization

Data Collection and Preprocessing

Figure 3. Human Tracking.

Video Collection. We collect in-the-wild, public domain group dancing videos along with the music from YouTube, TikTok, and Facebook. All group dance videos are processed at 1920 × 1080 resolution and 30 FPS.

Human Tracking. We perform tracking for all humans in the videos using a state-of-the-art multi-object tracker [1] to obtain the tracking bounding boxes. Note that although the tracker can produce reasonable results, there are failure cases in some frames. Therefore, we manually correct the bounding boxes in the incorrect cases. This tracking correction is crucial since we want the trajectory of each person to be accurately tracked in order to reconstruct their motion in later stages.

Pose Estimation. Given the bounding boxes of each person in the video, we leverage a state-of-the-art 2D pose estimation method [2] to generate the initial 2D poses for each person. In practice, there exist some inaccurately detected keypoints due to motion blur and partial occlusion. We manually fix the incorrect cases to obtain the 2D keypoints of each human bounding box.

Local Mesh Fitting

Figure 4. Local Mesh Fitting.

To construct 3D group dance motion, we first reconstruct the full body motion for each dancer by fitting the 3D mesh. We then jointly optimize all dancer motions to construct the globally-coherent group motion. Finally, we post-process and remove wrong cases from the optimization results.

We use the SMPL model [3] to represent the 3D human. The SMPL model is a differentiable function that maps the pose parameters $\mathbf{\theta}$, the shape parameters $\mathbf{\beta}$, and the root translation $\mathbf{\tau}$ into a set of 3D human body mesh vertices $\mathbf{V}\in \mathbb{R}^{6890\times3}$ and 3D joints $\mathbf{X}\in \mathbb{R}^{J\times3}$, where $J$ is the number of body joints.

The motion variables we optimize for each individual dancer consist of a sequence of SMPL joint angles $\{\mathbf{\theta}_t\}_{t=1}^T$, a sequence of root translations $\{\mathbf{\tau}_t\}_{t=1}^T$, and a single SMPL shape parameter $\mathbf{\beta}$. We fit the sequence of SMPL motion variables to the tracked 2D keypoints by extending the SMPLify-X framework [4] across the whole video sequence:

$$E_{\rm local} = E_{\rm J} + \lambda_{\theta}E_{\theta} + \lambda_{\beta} E_{\beta} + \lambda_{\rm S}E_{\rm S} + \lambda_{\rm F}E_{\rm F}$$


  • $E_{\rm J}$ is the 2D reprojection term between the 2D keypoints and the 2D projection of the corresponding 3D poses.
  • $E_{\theta}$ is the pose prior term from the latent space of the VPoser model [4] to encourage plausible human poses.
  • $E_{\beta}$ is the shape prior term to regularize the body shape towards the mean shape of the SMPL body model.
  • $E_{\rm S} = \sum_{t=1}^{T-1}\Vert \mathbf{\theta}_{t+1} - \mathbf{\theta}_{t} \Vert^2 + \sum_{j=1}^J\sum_{t=1}^{T-1}\Vert \mathbf{X}_{j,t+1} - \mathbf{X}_{j,t} \Vert^2$ is the smoothness term to encourage the temporal smoothness of the motion.
  • $E_{\rm F} = \sum_{t=1}^{T-1} \sum_{j \in \mathcal{F}} c_{j,t}\Vert \mathbf{X}_{j,t+1} - \mathbf{X}_{j,t} \Vert^2$ ensures that the feet joints stay stationary when in contact (zero velocity), where $\mathcal{F}$ is the set of feet joint indexes and $c_{j,t}$ is the feet contact of joint $j$ at time $t$.
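The two temporal terms at the end of the list above are simple sums of squared frame-to-frame differences. A minimal NumPy sketch, with hypothetical function names and array layouts chosen for the example, could look like this:

```python
import numpy as np

def smoothness_term(theta, joints):
    """E_S. theta: (T, P) joint angles; joints: (T, J, 3) 3D joint positions."""
    pose_diff = theta[1:] - theta[:-1]
    joint_diff = joints[1:] - joints[:-1]
    return float((pose_diff ** 2).sum() + (joint_diff ** 2).sum())

def foot_contact_term(joints, contact, foot_idx):
    """E_F. contact: (T, len(foot_idx)) with 1 when the foot joint touches the ground."""
    feet = joints[:, foot_idx]                               # (T, |F|, 3)
    vel_sq = ((feet[1:] - feet[:-1]) ** 2).sum(axis=-1)      # (T-1, |F|)
    # Only frames labelled as in contact are penalized for moving.
    return float((contact[:-1] * vel_sq).sum())
```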

Global Optimization

Figure 5. Global Optimization.

Given the 3D motion sequence of each dancer $p$: $\{\mathbf{\theta}^p_t, \mathbf{\tau}^p_t\}$, we further resolve the motion trajectory problems in group dance by solving the following objective:

$$E_{\rm global} = E_{\rm J} + \lambda_{\rm pen}E_{\rm pen} + \lambda_{\rm reg}\sum_{p}E_{\rm reg}(p) + \lambda_{\rm dep}\sum_{p,p',t}E_{\rm dep}(p,p',t) + \lambda_{\rm gc}\sum_{p}E_{\rm gc}(p)$$

$E_{\rm pen}$ is the Signed Distance Function penetration term to prevent the overlapping of reconstructed motions between dancers.

$E_{\rm reg}(p) = \sum_{t=1}^T\Vert \mathbf{\theta}^p_t - \hat{\mathbf{\theta}}^p_t\Vert^2$ is the regularization term that prevents the motion from deviating too much from the prior optimized individual motion $\{\hat{\mathbf{\theta}}^p_t\}$ obtained by optimizing the local mesh for dancer $p$.

In practice, we find that the relative depth ordering of dancers in the scene can be inconsistent due to the ambiguity of the 2D projection. To ensure the group motion quality, we watch the videos and manually provide the ordinal depth relation information of all dancers in the scene at each frame $t$ as follows:

$$r_t(p,p') = \begin{cases} 1, &\text{if dancer } p \text{ is closer than } p' \\ -1, &\text{if dancer } p \text{ is farther than } p' \\ 0, &\text{if their depths are roughly equal} \end{cases}$$

Given the relative depth information, we derive the depth relation term $E_{\rm dep}$. This term encourages a consistent ordinal depth relation between the motion trajectories of multiple dancers, especially when dancers partially occlude each other:

$$E_{\rm dep}(p,p',t) = \begin{cases} \log(1+\exp(z^p_t - z^{p'}_t)), &r_t(p,p')=1 \\ \log(1+\exp(-z^p_t + z^{p'}_t)), &r_t(p,p')=-1 \\ (z^p_t - z^{p'}_t)^2, &r_t(p,p')=0 \\ \end{cases}$$

where $z^p_t$ is the depth component of the root translation $\mathbf{\tau}^p_t$ of person $p$ at frame $t$. Intuitively, for $r_t(p,p')=1$, $z^p_t$ should be smaller than $z^{p'}_t$, and vice versa.
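For a single pair of dancers, the depth relation term can be evaluated over all frames as in the following sketch. The function name and the vectorized layout are assumptions; `np.logaddexp(0, d)` is used as a numerically stable form of $\log(1+\exp(d))$.

```python
import numpy as np

def depth_order_term(z_p, z_q, r):
    """E_dep summed over frames for one pair of dancers.

    z_p, z_q: (T,) root depths of dancers p and p'; r: (T,) labels in {1, -1, 0}.
    """
    d = z_p - z_q
    closer = np.logaddexp(0.0, d)     # r = 1: penalize p being behind p'
    farther = np.logaddexp(0.0, -d)   # r = -1: penalize p being in front of p'
    equal = d ** 2                    # r = 0: pull the two depths together
    per_frame = np.where(r == 1, closer, np.where(r == -1, farther, equal))
    return float(per_frame.sum())
```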

Finally, we apply the global ground contact constraint $E_{\rm gc}$ to further ensure consistency between the motion of every person and the environment based on the ground contact information. This contact term is also needed to reduce artifacts such as foot-skating, jittering, and penetration under the ground.

$$E_{\rm gc}(p) = \sum_{t=1}^{T-1} \sum_{j \in \mathcal{F}} c^p_{j,t}\Vert \mathbf{X}^p_{j,t+1} - \mathbf{X}^p_{j,t} \Vert^2 + c^p_{j,t} \Vert (\mathbf{X}^p_{j,t} - \mathbf{f})^\top \mathbf{n}^* \Vert^2$$

where $\mathcal{F}$ is the set of feet joint indexes, $\mathbf{n}^*$ is the estimated plane normal, and $\mathbf{f}$ is a fixed 3D point on the ground plane. The first term in the equation above is the zero velocity constraint when the feet are in contact with the ground, while the second term encourages the feet position to stay near the ground when in contact. To obtain the ground plane parameters, we initialize the plane point $\mathbf{f}$ as the weighted median of all contact feet positions. The plane normal $\mathbf{n}^*$ is obtained by optimizing a robust Huber objective:

$$\mathbf{n}^* = \arg\min_{\mathbf{n}} \sum_{\mathbf{X}_{\rm feet}} \mathcal{H}\left((\mathbf{X}_{\rm feet} - \mathbf{f})^\top \frac{\mathbf{n}}{\Vert\mathbf{n}\Vert}\right) + \Vert \mathbf{n}^\top\mathbf{n} - 1 \Vert^2,$$

where $\mathcal{H}$ is the Huber loss function and $\mathbf{X}_{\rm feet}$ denotes the 3D feet positions of all dancers across the whole sequence that are labelled as in contact (i.e., $c^p_{j,t} = 1$).
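The ground contact term and the plane-normal objective can be evaluated as in the sketch below. It only computes the two quantities for given inputs and does not run the optimization itself; the Huber threshold `delta` and all function names are assumptions.

```python
import numpy as np

def huber(x, delta=1.0):
    """Huber loss: quadratic near zero, linear for |x| > delta."""
    a = np.abs(x)
    return np.where(a <= delta, 0.5 * a ** 2, delta * (a - 0.5 * delta))

def plane_normal_objective(n, feet_points, f):
    """Objective minimized over n. feet_points: (M, 3) contact feet positions; f: (3,)."""
    signed_dist = (feet_points - f) @ (n / np.linalg.norm(n))
    return float(huber(signed_dist).sum() + (n @ n - 1.0) ** 2)

def ground_contact_term(joints, contact, foot_idx, f, n):
    """E_gc for one dancer. joints: (T, J, 3); contact: (T, len(foot_idx))."""
    n = n / np.linalg.norm(n)
    feet = joints[:, foot_idx]                               # (T, |F|, 3)
    vel_sq = ((feet[1:] - feet[:-1]) ** 2).sum(axis=-1)      # zero-velocity term
    height_sq = ((feet[:-1] - f) @ n) ** 2                   # distance to the ground plane
    return float((contact[:-1] * (vel_sq + height_sq)).sum())
```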

How will AIOZ-GDANCE be useful to the community?

We highlight some interesting research directions that can benefit from our dataset:

  • Group Dance Generation
  • Human Pose Tracking
  • Dance Education
  • Dance style transfer
  • Human behavior analysis

While single-person choreography has recently been a hot research topic, group dance generation has not yet been well investigated. We hope that the release of our dataset will foster more work in this research direction.


[1] Peize Sun, Jinkun Cao, Yi Jiang, Zehuan Yuan, Song Bai, Kris Kitani, and Ping Luo. Dancetrack: Multi-object tracking in uniform appearance and diverse motion. In CVPR, 2022

[2] Hao-Shu Fang, Shuqin Xie, Yu-Wing Tai, and Cewu Lu. RMPE: Regional multi-person pose estimation. In ICCV, 2017.

[3] Matthew Loper, Naureen Mahmood, Javier Romero, Gerard Pons-Moll, and Michael J. Black. SMPL: A skinned multi-person linear model. ACM Trans. Graphics, 2015.

[4] Georgios Pavlakos, Vasileios Choutas, Nima Ghorbani, Timo Bolkart, Ahmed A. A. Osman, Dimitrios Tzionas, and Michael J. Black. Expressive body capture: 3d hands, face, and body from a single image. In CVPR, 2019.

Uncertainty-aware Label Distribution Learning for Facial Expression Recognition (Part 1)

Facial expression recognition (FER) plays an important role in understanding people's feelings and interactions between humans. Recently, automatic emotion recognition has gained a lot of attention from the research community due to its many applications in education, healthcare, human analysis, surveillance, and human-robot interaction. Recent FER methods are mostly based on deep learning and can achieve impressive results. The success of deep models can be attributed to large-scale FER datasets [1][2]. However, the ambiguity of facial expressions is still a key challenge in FER. Specifically, people with different backgrounds might perceive and interpret facial expressions differently, which can lead to noisy and inconsistent annotations. In addition, real-life facial expressions usually manifest a mixture of feelings rather than only a single emotion.

Motivation and Proposed Solution

Figure 1. Examples of real-world ambiguous facial expressions that can lead to noisy and inconsistent annotation.

As an example, Figure 1 shows that people may have different opinions about the expressed emotion, particularly in ambiguous images. Consequently, a distribution over emotion categories is better than a single label because it takes all sentiment classes into account and can cover various interpretations, thus mitigating the effect of ambiguity. However, existing large-scale FER datasets only provide a single label for each sample instead of a label distribution, which means we do not have a comprehensive description for each facial expression. This can lead to insufficient supervision during training and pose a big challenge for many FER systems.

To overcome the ambiguity problem in FER, we propose a new uncertainty-aware label distribution learning method that constructs emotion distributions for training samples. Specifically, we leverage the neighborhood information of samples that have similar expressions to construct the emotion distributions from single labels and utilize them as the training supervision signal.



We denote $\mathbf{x} \in \mathcal{X}$ as the instance variable in the input space $\mathcal{X}$ and $\mathbf{x}^{i}$ as the particular $i$-th instance. The label set is denoted as $\mathcal{Y} = \{y_1, y_2, ..., y_m\}$ where $m$ is the number of classes and $y_j$ is the label value of the $j$-th class. The logical label vector of $\mathbf{x}^{i}$ is indicated by $\mathbf{l}^{i} = (l^{i}_{y_1}, l^{i}_{y_2}, ..., l^{i}_{y_m})$ with $l^{i}_{y_j} \in \{0, 1\}$ and $\| \mathbf{l} \|_1 = 1$. We define the label distribution of $\mathbf{x}^{i}$ as $\mathbf{d}^{i} = (d^{i}_{y_1}, d^{i}_{y_2}, ..., d^{i}_{y_m})$ with $\| \mathbf{d} \|_1 = 1$ and $d^{i}_{y_j} \in [0, 1]$ representing the relative degree that $\mathbf{x}^{i}$ belongs to the class $y_j$.

Most existing FER datasets assign only a single class or, equivalently, a logical label $\mathbf{l}^{i}$ for each training sample $\mathbf{x}^{i}$. In particular, the given training dataset is a collection of $n$ samples with logical labels $D_l = \{ (\mathbf{x}^{i}, \mathbf{l}^{i}) \vert 1 \le i \le n\}$. However, we find that a label distribution $\mathbf{d}^i$ is a more comprehensive and suitable annotation for the image than a single label.

Inspired by the recent success of label distribution learning (LDL) in addressing label ambiguity [3], we aim to construct an emotion distribution $\mathbf{d}^i$ for each training sample $\mathbf{x}^i$, thus transforming the training set $D_l$ into $D_d = \{ (\mathbf{x}^{i}, \mathbf{d}^{i}) \vert 1 \le i \le n\}$, which can provide richer supervision information and help mitigate the ambiguity issue. We use cross-entropy to measure the discrepancy between the model's prediction and the constructed target distribution. Hence, the model can be trained by minimizing the following classification loss:

$$\mathcal{L}_{cls} = \sum_{i=1}^n \text{CE}\left(\mathbf{d}^i, f(\mathbf{x}^i; \theta)\right) = -\sum_{i=1}^n \sum_{j=1}^m d_j^{i} \log f_j(\mathbf{x}^{i};\theta)$$

where $f(\mathbf{x}; \theta)$ is a neural network with parameters $\theta$ followed by a softmax layer to map the input image $\mathbf{x}$ into an emotion distribution.
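As a rough illustration of the classification loss above, the following sketch computes the cross-entropy between soft target distributions and softmax outputs in plain Python. The function names, the `eps` guard, and the toy numbers are assumptions made for this example; they are not taken from the original implementation.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classification_loss(target_dists, logits_batch, eps=1e-12):
    """Sum over samples of CE(d^i, f(x^i)) = -sum_j d^i_j * log f_j(x^i)."""
    loss = 0.0
    for d, logits in zip(target_dists, logits_batch):
        probs = softmax(logits)
        # eps (an assumption) avoids log(0) for vanishing probabilities.
        loss -= sum(dj * math.log(pj + eps) for dj, pj in zip(d, probs))
    return loss

# Toy batch: two samples, three emotion classes (values are made up).
targets = [[0.7, 0.2, 0.1], [0.0, 1.0, 0.0]]
logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
loss = classification_loss(targets, logits)
```

With a one-hot target the expression reduces to the usual single-label cross-entropy, so the soft-label form is a strict generalization.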


Figure 2. An overview of our Label Distribution Learning with Valence-Arousal (LDLVA) for facial expression recognition under ambiguity.

An overview of our method is presented in Figure 2. To construct the label distribution for each training instance $\mathbf{x}^i$, we leverage its neighborhood information in the valence-arousal space. Particularly, we identify $K$ neighbor instances for each training sample $\mathbf{x}^i$ and utilize our adaptive similarity mechanism to determine their contribution degrees to the target distribution $\mathbf{d}^i$. Then, we combine the neighbors' predictions and their corresponding contribution degrees with the provided label $\mathbf{l}^i$ and $\mathbf{l}^i$'s uncertainty factor to obtain the label distribution $\mathbf{d}^i$. The constructed distribution $\mathbf{d}^i$ will be used as supervision information to train the model via label distribution learning.

Adaptive Similarity

We assume that the label distribution of the main instance $\mathbf{x}^i$ can be computed as a linear combination of its neighbors' distributions. To determine the contribution of each neighbor, we propose an adaptive similarity mechanism that not only leverages the relationships between $\mathbf{x}^i$ and its neighbors in the auxiliary space but also utilizes their feature vectors extracted from the backbone. We choose the valence-arousal [4] as the auxiliary space to construct the target label distribution. We use the $K$-Nearest Neighbor algorithm to identify $K$ closest points for each training sample $\mathbf{x}^i$, denoted as $N(i)$. We calculate the adaptive contribution degrees of neighbor instances as the product of the local similarity $s^i_k$ and the calibration score $\zeta^i_k$ as follows:

$$c^i_k = \begin{cases} \zeta^i_k s^i_k, &\text{for } \mathbf{x}^k \in N(i), \\ 0, &\text{otherwise}. \end{cases}$$

where the local similarity $s^i_k$ is defined based on the distance between the instance and its neighbor in the valence-arousal space, $\mathbf{a}^i$ and $\mathbf{a}^k$:

$$s^i_k = \exp\left(-\frac{\| \mathbf{a}^i - \mathbf{a}^k \|^2_2}{\delta^2}\right), \quad \forall \mathbf{x}^k \in N(i)$$

We utilize a multilayer perceptron (MLP) $g$ with parameter $\phi$ to calculate the adaptive calibration score from the extracted features of the two instances $\mathbf{v}^i$ and $\mathbf{v}^k$ obtained from the backbone:

$$\zeta^i_k = \mathrm{Sigmoid}\left(g([\mathbf{v}^i,\mathbf{v}^{k}];\phi)\right)$$

The proposed adaptive similarity can correct similarity errors in the valence-arousal space. Such errors arise because valence-arousal values are not always available in practice, so we leverage an existing method to generate pseudo valence-arousal values.
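The contribution degrees described in this section can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the MLP $g$ is replaced by a single linear layer with hand-picked weights, and the valence-arousal coordinates, features, and helper names are hypothetical.

```python
import math

def knn_indices(points, i, k):
    """Indices of the k nearest neighbours of points[i], excluding i itself."""
    dists = [(math.dist(points[i], p), j) for j, p in enumerate(points) if j != i]
    return [j for _, j in sorted(dists)[:k]]

def local_similarity(a_i, a_k, delta):
    """s^i_k = exp(-||a^i - a^k||^2 / delta^2)."""
    return math.exp(-math.dist(a_i, a_k) ** 2 / delta ** 2)

def calibration_score(v_i, v_k, weights, bias):
    """zeta^i_k = sigmoid(g([v^i, v^k])); g is a one-layer stand-in for the MLP."""
    z = sum(w * x for w, x in zip(weights, v_i + v_k)) + bias
    return 1.0 / (1.0 + math.exp(-z))

def contribution_degrees(va, feats, i, k, delta, weights, bias):
    """c^i_k = zeta^i_k * s^i_k for x^k in N(i); non-neighbours (c = 0) are omitted."""
    return {
        j: calibration_score(feats[i], feats[j], weights, bias)
        * local_similarity(va[i], va[j], delta)
        for j in knn_indices(va, i, k)
    }

# Made-up valence-arousal coordinates and 2-D backbone features.
va = [(0.1, 0.2), (0.15, 0.25), (0.9, -0.8), (0.0, 0.3)]
feats = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.5], [0.8, 0.2]]
c = contribution_degrees(va, feats, i=0, k=2, delta=0.5,
                         weights=[0.5, -0.2, 0.5, -0.2], bias=0.0)
```

Because the calibration score lies in $(0, 1)$, it can only scale a neighbor's local similarity down, which is how a learned signal can discount neighbors whose pseudo valence-arousal coordinates are misleading.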

Uncertainty-aware Label Distribution Construction

After obtaining the contribution degree of each neighbor $\mathbf{x}^k \in N(i)$, we can now generate the target label distribution $\mathbf{d}^i$ for the main instance $\mathbf{x}^i$. The target label distribution is calculated using the logical label $\mathbf{l}^i$ and the aggregated distribution $\tilde{\mathbf{d}}^i$ defined as follows:

$$\tilde{\mathbf{d}}^i = \frac{\sum_k c^i_k f(\mathbf{x}^{k};\theta)}{\sum_k c^i_k}$$

$$\mathbf{d}^i = (1-\lambda^i) \mathbf{l}^i + \lambda^i \tilde{\mathbf{d}}^i$$

where $\lambda^i \in [0,1]$ is the uncertainty factor for the logical label. It controls the balance between the provided label $\mathbf{l}^i$ and the aggregated distribution $\tilde{\mathbf{d}}^i$ from the local neighborhood.

Intuitively, a high value of $\lambda^i$ indicates that the logical label is highly uncertain, which can be caused by ambiguous expression or low-quality input images, thus we should put more weight towards neighborhood information $\tilde{\mathbf{d}}^i$. Conversely, when $\lambda^i$ is small, the label distribution $\mathbf{d}^i$ should be close to $\mathbf{l}^i$ since we are certain about the provided manual label. In our implementation, $\lambda^i$ is a trainable parameter for each instance and will be optimized jointly with the model's parameters using gradient descent.
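To make the construction concrete, here is a small sketch of the two formulas above. It treats $\lambda^i$ as a fixed number for illustration, whereas the post describes it as a trainable per-instance parameter; the function names and example values are assumptions.

```python
def aggregate_neighbours(contribs, neighbour_preds):
    """d_tilde^i = sum_k c^i_k f(x^k) / sum_k c^i_k."""
    total = sum(contribs)
    num_classes = len(neighbour_preds[0])
    return [
        sum(c * pred[j] for c, pred in zip(contribs, neighbour_preds)) / total
        for j in range(num_classes)
    ]

def target_distribution(logical_label, d_tilde, lam):
    """d^i = (1 - lambda^i) * l^i + lambda^i * d_tilde^i."""
    return [(1.0 - lam) * l + lam * d for l, d in zip(logical_label, d_tilde)]

# Made-up example: 3 classes, 2 neighbours with their softmax predictions.
contribs = [0.7, 0.3]
preds = [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]]
d_tilde = aggregate_neighbours(contribs, preds)
d = target_distribution([1, 0, 0], d_tilde, lam=0.4)
```

Since both the one-hot label and the aggregated distribution sum to one, any $\lambda^i \in [0, 1]$ yields a valid distribution; setting `lam=0` recovers the original single label.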

Loss Function

To enhance the model's ability to discriminate between ambiguous emotions, we also propose a discriminative loss to reduce the intra-class variations of the learned facial representations. We incorporate the label uncertainty factor $\lambda^i$ to adaptively penalize the distance between the sample and its corresponding class center. For instances with high uncertainty, the network can effectively tolerate their features in the optimization process. Furthermore, we also add pairwise distances between class centers to encourage large margins between different classes, thus enhancing the discriminative power. Our discriminative loss is calculated as follows:

$$\mathcal{L}_D = \frac{1}{2}\sum_{i=1}^n (1-\lambda^i)\Vert \mathbf{v}^i - \mathbf{\mu}_{y^i} \Vert_2^2 + \sum_{j=1}^m \sum_{\substack{k=1 \\ k \neq j}}^m \exp \left(-\frac{\Vert\mathbf{\mu}_{j}-\mathbf{\mu}_{k}\Vert_2^2}{\sqrt{V}}\right)$$

where $y^i$ is the class index of the $i$-th sample while $\mathbf{\mu}_{j}$, $\mathbf{\mu}_{k}$, and $\mathbf{\mu}_{y^i} \in \mathbb{R}^V$ are the center vectors of the $j$-th, $k$-th, and $y^i$-th classes, respectively. Intuitively, the first term of $\mathcal{L}_D$ encourages the feature vectors of one class to be close to their corresponding center while the second term improves the inter-class discrimination by pushing the cluster centers far away from each other. Finally, the total loss for training is computed as:

$$\mathcal{L} = \mathcal{L}_{cls} + \gamma\mathcal{L}_D$$

where $\gamma$ is the balancing coefficient between the two losses.
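The discriminative loss and the total loss above can be written out directly. The sketch below is a plain-Python illustration with made-up features, centers, and uncertainty factors; in the actual method the centers and $\lambda^i$ would be learned, which this example does not attempt.

```python
import math

def sq_dist(u, v):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def discriminative_loss(feats, labels, lams, centers):
    """Uncertainty-weighted pull towards class centers plus a center-separation term."""
    dim = len(centers[0])  # V, the feature dimension
    pull = 0.5 * sum((1.0 - lam) * sq_dist(v, centers[y])
                     for v, y, lam in zip(feats, labels, lams))
    push = sum(math.exp(-sq_dist(cj, ck) / math.sqrt(dim))
               for j, cj in enumerate(centers)
               for k, ck in enumerate(centers) if k != j)
    return pull + push

def total_loss(l_cls, l_d, gamma):
    """L = L_cls + gamma * L_D."""
    return l_cls + gamma * l_d

# Made-up 2-D features for two classes; the second sample is more uncertain.
centers = [[0.0, 0.0], [2.0, 0.0]]
feats = [[0.1, 0.0], [1.8, 0.2]]
l_d = discriminative_loss(feats, labels=[0, 1], lams=[0.0, 0.5], centers=centers)
```

Note how the factor $(1 - \lambda^i)$ halves the pull on the second sample: the more uncertain a label is, the less its feature is forced towards the center of that class.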


[1] Ali Mollahosseini, Behzad Hasani, and Mohammad H. Mahoor. Affectnet: A database for facial expression, valence, and arousal computing in the wild. IEEE Transactions on Affective Computing, 2019

[2] Shan Li, Weihong Deng, and JunPing Du. Reliable crowdsourcing and deep locality-preserving learning for expression recognition in the wild. In CVPR, 2017.

[3] B. Gao, C. Xing, C. Xie, J. Wu, and X. Geng. Deep label distribution learning with label ambiguity. IEEE Transactions on Image Processing, 2017.

Uncertainty-aware Label Distribution Learning for Facial Expression Recognition (Part 2)

In the previous post, we introduced our proposed method for Facial Expression Recognition. In this post, we will examine the effectiveness and efficiency of the proposal.

Experimental Results

Noisy and Inconsistent Labels

Table 1. Test performance with synthetic noise.

We conduct experiments to study the robustness of our LDLVA on mislabelled data by adding synthetic noise to the AffectNet, RAF-DB, and SFEW datasets. Specifically, we randomly flip the manual labels to one of the other categories. We report the mean accuracy and standard error in Table 1. The results clearly show that our method consistently outperforms other approaches in all cases. We also observe that the improvements are even more apparent when the noise ratio increases; for example, the accuracy improvement on RAF-DB is 4.7% with 10% noise and 6.93% with 30% noise. The consistent results under various settings demonstrate the ability of our method to effectively deal with noisy annotation, which is crucial for robustness against label ambiguity.
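The label-flipping protocol can be sketched as follows. The post only states that labels are randomly flipped to one of the other categories; flipping an exact fraction of the samples, the fixed seed, and the function name `flip_labels` are assumptions of this example.

```python
import random

def flip_labels(labels, num_classes, noise_ratio, seed=0):
    """Return a copy of labels where a noise_ratio fraction is flipped to another class."""
    rng = random.Random(seed)
    noisy = list(labels)
    n_flip = round(noise_ratio * len(labels))
    for idx in rng.sample(range(len(labels)), n_flip):
        # Choose uniformly among all classes except the original one.
        others = [c for c in range(num_classes) if c != labels[idx]]
        noisy[idx] = rng.choice(others)
    return noisy

# Toy data: 100 samples over 7 emotion classes, 30% synthetic noise.
clean = [i % 7 for i in range(100)]
noisy = flip_labels(clean, num_classes=7, noise_ratio=0.3)
```

Excluding the original class from the candidates guarantees that every selected sample is actually mislabelled, so the realized noise ratio matches the requested one.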

Table 2. Test performance with inconsistent labels between cross-datasets.

Since the annotations for large-scale FER data are commonly obtained via crowd-sourcing, this can create label inconsistency, especially between different datasets. To examine the effectiveness of our proposed method in dealing with this problem, we also perform experiments with the cross-dataset protocol. Table 2 shows that our method achieves the best performance on all three datasets as well as the highest average accuracy, surpassing the current state-of-the-art methods. This confirms the advantages of our method over previous works and demonstrates its generalization ability to data with label inconsistency, which is essential for real-world FER applications.

Comparison with the State of the Art

Table 3. Comparison with recent methods on the original datasets.

We further compare our method with several state-of-the-art methods on the original AffectNet, RAF-DB, and SFEW to evaluate the robustness of our method to the uncertainty and ambiguity that unavoidably exist in real-world FER datasets. The results are presented in Table 3. By leveraging label distribution learning on valence-arousal space, our model outperforms other methods and achieves state-of-the-art performance on AffectNet, RAF-DB, and SFEW. Although these datasets are considered to be "clean", the results suggest that they indeed suffer from uncertainty and ambiguity.

Qualitative Analysis

Real-world Ambiguity: To understand more about real-world ambiguous expressions, we conducted a user study in which we asked participants to choose the most clearly expressed emotion on random test images. We compare our model's predictions with the survey results in Figure 3. We can see that these images are ambiguous as they express a combination of different emotions, hence the participants do not fully agree and have different opinions about the most prominent emotion on the faces. It is further shown that our model can give consistent results and agree with the perception of humans to some degree.

Figure 3. Comparison of the results from our user study and our model.

Uncertainty Factor: Figure 4 shows the estimated uncertainty factors of some training images and their original labels. The uncertainty values decrease from top to bottom. Highly uncertain labels can be caused by low-quality inputs (as shown in Angry and Surprise columns) or ambiguous facial expressions. In contrast, when the emotions can be easily recognized as those in the last row, the uncertainty factors are assigned low values. This characteristic can guide the model to decide whether to put more weight on the provided label or the neighborhood information. Therefore, the model can be more robust against uncertainty and ambiguity.

Figure 4. Visualization of uncertainty values of some examples from RAF-DB dataset.


We have introduced a new label distribution learning method for facial expression recognition by leveraging structure information in the valence-arousal space to recover the intensities distributed over emotion categories. The constructed label distribution provides rich information about the emotions and can thus effectively describe the ambiguity degree of the facial image. Extensive experiments on popular datasets demonstrate the effectiveness of our method over previous approaches under inconsistency and uncertainty conditions in facial expression recognition.



Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 3)

In this part, we will show the effectiveness of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge, together with ablation studies.


As mentioned in [1], we train our method on two types of scans: Liver CT scans and Brain MRI scans.

For Liver CT scans, we use 5 datasets:

  1. LiTS contains 131 liver segmentation scans.
  2. MSD has 70 liver tumor CT scans, 443 hepatic vessels scans, and 420 pancreatic tumor scans.
  3. BFH is a smaller dataset with 92 scans.
  4. SLIVER is a challenging dataset with 20 liver segmentation scans annotated by 3 expert doctors.
  5. LSPIG (Liver Segmentation of Pigs) contains 17 pairs of CT scans from pigs, provided by the First Affiliated Hospital of Harbin Medical University.

For Brain MRI scans, we use 4 datasets:

  1. ADNI contains 66 scans.
  2. ABIDE contains 1287 scans.
  3. ADHD contains 949 scans.
  4. LPBA has 40 scans, each featuring a segmentation ground truth of 56 anatomical structures.


We compare our LDR + ALDK method with the following recent deformable registration methods:

  • ANTs SyN and Elastix B-spline are methods that find an optimal transformation by iteratively updating the parameters of the defined alignment.
  • VoxelMorph predicts a dense deformation in an unsupervised manner by using deconvolutional layers.
  • VTN is an end-to-end learning framework that uses convolutional neural networks to register 3D medical images, especially large displaced ones.
  • RCN is a recent recursive deep architecture that utilizes learnable cascade and performs progressive deformation for each warped image.


Table 1 summarizes the overall performance, testing speed, and the number of parameters compared with recent state-of-the-art methods in the deformable registration task. The results clearly show that our Light-weight Deformable Registration network (LDR), accompanied by the Adversarial Learning with Distilling Knowledge (ALDK) algorithm, significantly reduces the inference time and the number of parameters during the inference phase. Moreover, our method achieves accuracy competitive with the most recent high-performing but expensive networks, such as VTN or VoxelMorph. We notice that this improvement is consistent across all experiments on the SLIVER, LiTS, LSPIG, and LPBA datasets.

In particular, we observe that on the SLIVER dataset the Dice score of our best model with 3 cascades (3-cas LDR + ALDK) is 0.3% lower than the best result of 3-cas VTN + Affine, while its inference speed is ~21 times faster on a CPU and the number of parameters used during inference is ~8 times smaller. Across the three other datasets, i.e., LiTS, LSPIG, and LPBA, our light-weight model trades off only an average of 0.5% in Dice score and 1.25% in Jacc score for a significant gain in speed and a massive reduction in the number of parameters. We also notice that our method is the only work that achieves an inference time of approximately 1s on a CPU. This makes our method well suited for deployment, as it does not require expensive GPU hardware for inference.
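For readers unfamiliar with the two accuracy metrics quoted above, the sketch below computes the Dice score and the Jaccard (Jacc) score between two binary segmentations. The post does not spell out the formulas, so the standard definitions are assumed here, and the flattened toy volumes are made up.

```python
def dice_and_jaccard(seg_a, seg_b):
    """Dice and Jaccard overlap between two binary masks given as flat lists of 0/1."""
    inter = sum(1 for a, b in zip(seg_a, seg_b) if a and b)
    size_a = sum(1 for a in seg_a if a)
    size_b = sum(1 for b in seg_b if b)
    union = size_a + size_b - inter
    if union == 0:
        # Both masks are empty: treat them as a perfect match (a convention, not from the post).
        return 1.0, 1.0
    dice = 2.0 * inter / (size_a + size_b)
    jacc = inter / union
    return dice, jacc

# Toy flattened volumes of 8 voxels each.
warped_seg = [1, 1, 1, 0, 0, 0, 0, 0]
fixed_seg = [0, 1, 1, 1, 0, 0, 0, 0]
dice, jacc = dice_and_jaccard(warped_seg, fixed_seg)
```

The two scores are monotonically related (Jacc = Dice / (2 - Dice)), which is why a small drop in Dice corresponds to a somewhat larger drop in Jacc.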



Ablation Study

Effectiveness of ALDK. Table 2 summarizes the effectiveness of Adversarial Learning with Distilling Knowledge (ALDK) when integrated into the light-weight student network. Note that LDR without ALDK is trained using only the reconstruction loss in an unsupervised learning setup. From this table, we clearly see that the ALDK algorithm improves the Dice score of the LDR tested on the SLIVER dataset by 3.4%, 4.0%, and 3.1% for the 1-cas, 2-cas, and 3-cas setups, respectively. Additionally, using ALDK also increases the Jacc score by 5.2%, 4.9%, and 3.9% for 1-cas LDR, 2-cas LDR, and 3-cas LDR. These results verify the stability of our adversarial learning algorithm in the inference phase under different evaluation metrics as well as different numbers of cascades. Furthermore, Table 2 also clearly shows the effectiveness and generalization of ALDK when applied to the student network. Since the deformations extracted from the teacher are used only during training, our adversarial learning algorithm fully preserves the speed and the number of parameters of the light-weight student network during inference. All results indicate that our student network combined with the adversarial learning algorithm successfully achieves the performance goal while maintaining the efficient computational cost of the light-weight setup.



Accuracy vs. Complexity. Figure 1 shows the experimental results on the SLIVER dataset for LDR + ALDK and the baseline VTN under multiple recursive cascade setups on both CPU and GPU. On the CPU (Figure 1-a), in the 1-cascade setup, the Dice score of our method is 0.2% lower than that of VTN while the speed is ~15 times faster. The more cascades are used, the larger the speed gap between LDR + ALDK and the baseline VTN, e.g. the CPU speed gap increases to ~21 times in the 3-cascade setup. We also observe the same effect on GPU (Figure 1-b), where our method achieves slightly lower accuracy than VTN while clearly reducing the inference time. These results indicate that LDR + ALDK can work well with the teacher network to improve the accuracy while significantly reducing the inference time on both CPU and GPU in comparison with the baseline VTN network.


Figure 1: Plots of Dice score and Inference speed with respect to the number of cascades of the baseline Affine + VTN and LDR + ALDK. (a) for CPU speed and (b) for GPU speed. Note that results are reported for the SLIVER dataset; bars represent the CPU speed; lines represent the Dice score. All methods use an Intel Xeon E5-2690 v4 CPU and Nvidia GeForce GTX 1080 Ti GPU for inference.


Figure 2 illustrates the visual comparison among 1-cas LDR, 1-cas LDR + ALDK, and the baseline 1-cas RCN. Five different moving images in a volume are selected to apply the registration to a chosen fixed image. It is important to note that although the sections of the warped segmentations may overlap less with those of the fixed one, the segmentation intersection over union is computed for the volume and not the sections. In the segmented images in Figure 2, besides the matched areas colored in white, we also mark the mismatched areas in red for readability.

From Figure 2, we can see that the segmentation results of the 1-cas LDR network without ALDK (Figure 2-a) contain many mismatched areas (denoted in red). However, when we apply ALDK to the student network, the registration results are clearly improved (Figure 2-b). Overall, the LDR + ALDK visualization results in Figure 2-b are competitive with the baseline RCN network (Figure 2-c). This visualization confirms that our framework for deformable registration can achieve results comparable to the recent RCN network.


Figure 2: The visualization comparison between LDR (a), LDR + ALDK (b), and the baseline RCN (c). The left images are sections of the warped images; the right images are sections of the warped segmentation (white represents the matched areas between the warped image and the fixed image, red denotes the mismatched areas). The segmentation visualization indicates that the LDR + ALDK method (b) significantly reduces the mismatched areas of the student network LDR (a). Best viewed in color.


[1] Tran, Minh Q., et al. "Light-weight deformable registration using adversarial learning with distilling knowledge." IEEE Transactions on Medical Imaging, 2022.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 2)

In this part, we will introduce the Architecture of Light-weight Deformable Registration Network and Adversarial Learning Algorithm with Distilling Knowledge.

The Architecture of Light-weight Deformable Registration Network

In practice, recent deformation networks follow an encoder-decoder architecture and use 3D convolution to progressively down-sample the image, and deconvolution (transposed convolution) to recover spatial resolution [1, 3]. However, this setup consumes a large number of parameters. Therefore, the built models are computationally expensive and time-consuming. To overcome this problem we design a new light-weight student network as illustrated in Figure 1.

In particular, the proposed light-weight network has four convolution layers and three deconvolution layers. Each convolutional layer has a bank of $4 \times 4 \times 4$ filters with strides of $2 \times 2 \times 2$, followed by a ReLU activation function. The number of output channels of the convolutional layers starts with $16$ at the first layer, doubling at each subsequent layer, and ends up with $128$. Skip connections between the convolutional layers and the deconvolutional layers are added to help refine the dense prediction. The subnetwork outputs a dense flow prediction field, i.e., a $3$ channels volume feature map with the same size as the input.
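A quick way to sanity-check the encoder description is to trace the spatial size and channel count through the four convolutional layers. The sketch below assumes a $128^3$ input, two input channels (moving and fixed image stacked), and a padding of 1; none of these three values is stated in the text above, so treat them as illustrative assumptions.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output length of one spatial axis after a strided convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def encoder_summary(input_size=128, in_channels=2, first_out=16, num_layers=4):
    """Per-layer (in_channels, out_channels, spatial_size, parameter_count)."""
    rows = []
    size, c_in, c_out = input_size, in_channels, first_out
    for _ in range(num_layers):
        size = conv_out(size)
        # 4x4x4 kernel weights for every (in, out) channel pair, plus one bias per output.
        params = c_out * (c_in * 4 ** 3 + 1)
        rows.append((c_in, c_out, size, params))
        c_in, c_out = c_out, c_out * 2  # channels double at each subsequent layer
    return rows

for c_in, c_out, size, params in encoder_summary():
    print(f"{c_in:>3} -> {c_out:<3} spatial {size}^3, {params} parameters")
```

Under these assumptions each layer halves the spatial resolution, so four layers reduce a $128^3$ volume to $8^3$ while the channels grow from 16 to 128.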

In comparison with the current state-of-the-art dense deformable registration network [3], the number of parameters of our proposed light-weight student network is reduced approximately $10$ times. In practice, this significant reduction may lead to an accuracy drop. Therefore, we propose a new Adversarial Learning with Distilling Knowledge algorithm to effectively leverage the teacher deformations $\phi_t$ to our introduced student network, making it light-weight but achieving competitive performance.


Figure 1: The structure of Light-weight Deformable Registration student network. The number of channels is annotated above the layer. Curved arrows represent skip paths (layers connected by an arrow are concatenated before transposed convolution). Smaller canvas means lower spatial resolution (Source).

Adversarial Learning Algorithm with Distilling Knowledge

Our adversarial learning algorithm aims to improve the student network accuracy through the distilled teacher deformations extracted from the teacher network. The learning method comprises a deformation-based adversarial loss $\mathcal{L}_{adv}$ and its accompanying learning strategy (Algorithm 1).


Figure 2: Adversarial Learning Strategy (Source).

Adversarial Loss. The loss function for the light-weight student network is a combination of the discrimination loss $l_{dis}$ and the reconstruction loss $l_{rec}$. However, the forward and backward passes through the loss function are controlled by Algorithm 1. In particular, the loss $\mathcal{L}_{adv}$ of the last deformation, which outputs the final warped image, can be written as:

$$\mathcal{L}_{adv} = \gamma l_{rec} + (1 - \gamma) l_{dis}$$

where $\gamma$ controls the contribution between $l_{rec}$ and $l_{dis}$. Note that $\mathcal{L}_{adv}$ is only applied for the final warped image.

Discrimination Loss. In the student network, the discrimination loss is computed as follows:

$$l_{dis} = \left\lVert D_\mathbf{\theta}(\phi_{s}) - D_\mathbf{\theta}(\phi_{t}) \right\rVert_2^{2} + \lambda\bigg(\left\lVert \nabla_{\hat\phi_{s}}D_\mathbf{\theta}(\hat\phi_{s}) \right\rVert_2 - 1\bigg)^{2}$$

where $\lambda$ controls gradient penalty regularization. The joint deformation $\hat\phi_{s}$ is computed from the teacher deformation $\phi_{t}$ and the predicted student deformation $\phi_{s}$ as follows:

$$\hat\phi_{s} = \beta \phi_{t} + (1 - \beta) \phi_{s}$$

where $\beta$ controls the effect of the teacher deformation.
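The pieces of the adversarial loss can be illustrated with a deliberately tiny example. The sketch below swaps the convolutional discriminator for a linear critic $D(\phi) = w \cdot \phi$, whose gradient with respect to its input is simply $w$, so the gradient penalty can be written in closed form. This substitution, the flattened 3-element deformations, and all function names are assumptions made for illustration; a real implementation would evaluate the gradient at $\hat\phi_{s}$ with automatic differentiation.

```python
import math

def joint_deformation(phi_t, phi_s, beta):
    """phi_hat_s = beta * phi_t + (1 - beta) * phi_s, element-wise."""
    return [beta * t + (1.0 - beta) * s for t, s in zip(phi_t, phi_s)]

def linear_critic(phi, w):
    """Toy stand-in for the discriminator: D(phi) = w . phi."""
    return sum(wi * p for wi, p in zip(w, phi))

def discrimination_loss(phi_s, phi_t, w, lam):
    """(D(phi_s) - D(phi_t))^2 + lam * (||grad D(phi_hat_s)||_2 - 1)^2."""
    gap = (linear_critic(phi_s, w) - linear_critic(phi_t, w)) ** 2
    # For a linear critic the input gradient equals w at every point, including phi_hat_s.
    grad_norm = math.sqrt(sum(wi * wi for wi in w))
    return gap + lam * (grad_norm - 1.0) ** 2

def adversarial_loss(l_rec, l_dis, gamma):
    """L_adv = gamma * l_rec + (1 - gamma) * l_dis."""
    return gamma * l_rec + (1.0 - gamma) * l_dis

# Made-up flattened deformations and critic weights.
phi_t, phi_s = [1.0, 0.0, 2.0], [0.0, 0.0, 0.0]
phi_hat = joint_deformation(phi_t, phi_s, beta=0.25)
l_dis = discrimination_loss(phi_s, phi_t, w=[0.6, 0.8, 0.0], lam=10.0)
l_adv = adversarial_loss(l_rec=0.2, l_dis=l_dis, gamma=0.5)
```

In this toy setting the critic weights have unit norm, so the penalty term vanishes and only the squared gap between the critic outputs for the student and teacher deformations remains.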

In the discrimination loss, $D_\mathbf{\theta}$ is the discriminator, formed by a neural network with learnable parameters ${\theta}$. The details of $D_\mathbf{\theta}$ are shown in Figure 3. In particular, $D_\mathbf{\theta}$ consists of six 3D convolutional layers. The first layer is $128 \times 128 \times 128 \times 3$ and takes the $c \times c \times c \times 1$ deformation as input, where $c$ is equal to the scaled size of the input image. The second layer is $64 \times 64 \times 64 \times 16$. From the second layer to the last convolutional layer, each convolutional layer has a bank of $4 \times 4 \times 4$ filters with strides of $2 \times 2 \times 2$, followed by a ReLU activation function, except for the last layer, which is followed by a sigmoid activation function. The number of output channels of the convolutional layers starts with $16$ at the second layer, doubling at each subsequent layer, and ends up with $256$.

Basically, this injects the condition information with a matched tensor dimension and then lets the network learn useful features from the condition input. The output of the last neural layer is the mean feature of the discriminator, denoted as $M$. Note that in the discrimination loss, a gradient penalty regularization is applied to deal with critic weight clipping, which may lead to undesired behavior in training adversarial networks.


Figure 3: The structure of the discriminator $D_\mathbf{\theta}$ used in the Discrimination Loss ($l_{dis}$) of our Adversarial Learning with Distilling Knowledge algorithm (Source).

Reconstruction Loss. The reconstruction loss $l_{rec}$ is an important part of a deformation estimator. Following the VTN [3] baseline, the reconstruction loss is written as:

$$l_{rec} (\textbf{\textit{I}}_m^h,\textbf{\textit{I}}_f) = 1 - CorrCoef [\textbf{\textit{I}}_m^h,\textbf{\textit{I}}_f]$$

$$CorrCoef[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2] = \frac{Cov[\textbf{\textit{I}}_1,\textbf{\textit{I}}_2]}{\sqrt{Cov[\textbf{\textit{I}}_1,\textbf{\textit{I}}_1]Cov[\textbf{\textit{I}}_2,\textbf{\textit{I}}_2]}}$$

$$Cov[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2] = \frac{1}{|\omega|}\sum_{x \in \omega} \textbf{\textit{I}}_1(x)\textbf{\textit{I}}_2(x) - \frac{1}{|\omega|^{2}}\sum_{x \in \omega} \textbf{\textit{I}}_1(x)\sum_{y \in \omega}\textbf{\textit{I}}_2(y)$$

where $CorrCoef[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2]$ is the correlation between two images $\textbf{\textit{I}}_1$ and $\textbf{\textit{I}}_2$, and $Cov[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2]$ is the covariance between them. $\omega$ denotes the cuboid (or grid) on which the input images are defined.
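The reconstruction loss translates almost line by line into code. The sketch below works on images flattened into Python lists; the function names and the toy intensities are assumptions for illustration only.

```python
import math

def cov(i1, i2):
    """Cov[I1, I2] = mean(I1 * I2) - mean(I1) * mean(I2) over the voxel grid."""
    n = len(i1)
    return sum(a * b for a, b in zip(i1, i2)) / n - sum(i1) * sum(i2) / n ** 2

def corr_coef(i1, i2):
    """Correlation coefficient between two images (undefined for constant images)."""
    return cov(i1, i2) / math.sqrt(cov(i1, i1) * cov(i2, i2))

def reconstruction_loss(warped, fixed):
    """l_rec = 1 - CorrCoef[warped, fixed]."""
    return 1.0 - corr_coef(warped, fixed)

# Made-up flattened intensities; the warped image is an affine rescaling of the fixed one.
fixed = [0.1, 0.4, 0.35, 0.8]
warped = [2.0 * x + 1.0 for x in fixed]
loss_aligned = reconstruction_loss(warped, fixed)
loss_inverted = reconstruction_loss([-x for x in fixed], fixed)
```

Because the correlation coefficient is invariant to affine intensity changes, a perfectly aligned pair gives a loss close to 0 even when brightness and contrast differ, while an inverted image gives a loss close to 2.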

Learning Strategy. The forward and backward of the aforementioned $\mathcal{L}_{adv}$ is controlled by the adversarial learning strategy described in Algorithm 1.

In our deformable registration setup, the roles of real data and attacking data are reversed when compared with the traditional adversarial learning strategy. In adversarial learning, the model uses unreal (generated) images as attacking data, while image labels are ground truths. However, in our deformable registration task, the model leverages the unreal (generated) deformations from the teacher as attacking data, while the image is the ground truth for the model to reconstruct the input information. As a consequence, the roles of the images and the labels are reversed in our setup. Since we want the information to be learned more from real data, the generator will need to be considered more frequently. Although the knowledge in the discriminator is used as attacking data, the information it supports is meaningful because the distilled information is inherited from the high-performing teacher model. With these characteristics of both the generator and discriminator, the light-weight student network is expected to learn more effectively and efficiently.


[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., Recursive cascaded networks for unsupervised medical image registration, in ICCV, 2019.

[2] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, ArXiv, 2015.

[3] S. Zhao, T. Lau, J. Luo, I. Eric, C. Chang, and Y. Xu, Unsupervised 3d end-to-end medical image registration with volume tweening network, IEEE J-BHI, 2019.


Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge

Introduction: Medical image registration

Medical image registration is the process of systematically placing separate medical images in a common frame of reference so that the information they contain can be effectively integrated or compared. Applications of image registration include combining images of the same subject from different modalities, aligning temporal sequences of images to compensate for the motion of the subject between scans, aligning images from multiple subjects in cohort studies, or navigating with image guidance during interventions. Since many organs do deform substantially while being scanned, the rigid assumption can be violated as a result of scanner-induced geometrical distortions that differ between images. Therefore, performing deformable registration is an essential step in many medical procedures.

Previous Studies, Remaining Challenges, and Motivation

Recently, learning-based methods have become popular to tackle the problem of deformable registration. These methods can be split into two groups: (i) supervised methods that rely on the dense ground-truth flows obtained by either traditional algorithms or simulating intra-subject deformations. Although these works achieve state-of-the-art performance, they require a large amount of manually labeled training data, which is expensive to obtain; and (ii) unsupervised learning methods that use a similarity measurement between the moving and the fixed image to utilize a large amount of unlabelled data. These unsupervised methods achieve competitive results in comparison with supervised methods. However, their deformations are reconstructed without direct ground-truth guidance, which limits the learnable information they can leverage. Furthermore, recent unsupervised methods all share an issue of great complexity, as the network parameters increase significantly when multiple progressive cascades are taken into account. As a result, these works cannot achieve real-time performance during inference and require intensive computational resources when deployed.

In practice, there are many scenarios in which medical image registration needs to be fast: matching preoperative and intra-operative images during surgery, interactive change detection in CT or MRI data for a radiologist, deformation compensation or 3D alignment of large histological slices for a pathologist, or processing large amounts of images from high-throughput imaging methods. Besides, in many image-guided robotic interventions, real-time deformable registration is an essential step to register the images and deal with organs that deform substantially. Economically, a CPU-friendly solution for deformable registration would significantly reduce the cost of equipping the operating theatre, as it does not require GPUs or cloud-based computing servers, which are costly and consume much more power than CPUs. This would benefit patients in low- and middle-income countries, which face limitations in local equipment, personnel expertise, infrastructure, and budget. Therefore, designing an efficient model that is both fast and accurate for deformable registration is a crucial task, and one worth studying in order to improve a variety of surgical interventions.


Deformable registration is a crucial step in many medical procedures such as image-guided surgery and radiation therapy. Most recent learning-based methods focus on improving accuracy by optimizing the non-linear spatial correspondence between the input images. As a result, these methods are computationally expensive and require modern graphics cards for real-time deployment. We therefore introduce a new Light-weight Deformable Registration network that significantly reduces the computational cost while achieving competitive accuracy (Fig.1). In particular, we propose a new adversarial learning with distilling knowledge algorithm that successfully transfers meaningful information from the effective but expensive teacher network to the student network. We design the student network so that it is light-weight and well suited for deployment on a typical CPU. Extensive experimental results on different public datasets show that our proposed method achieves state-of-the-art accuracy while being significantly faster than recent methods. We further show that our adversarial learning algorithm is essential for a time-efficient deformable registration method.


Figure 1: Comparison between typical deep learning-based methods for deformable registration (a) and our approach using adversarial learning with distilling knowledge for deformable registration (b). In our work, the expensive Teacher Network is used only in training; the Student Network is light-weight and inherits helpful knowledge from the Teacher Network via our Adversarial Learning algorithm. Therefore, the Student Network has high inference speed, while achieving competitive accuracy (Source).


Method overview

We describe our method for Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge. Our method is composed of three main components: (i) a Knowledge Distillation module which extracts meaningful deformations $\bm{\phi}_t$ from the Teacher Network; (ii) a Light-weight Deformable Registration (LDR) module which outputs a high-speed Student Network; and (iii) an Adversarial Learning with Distilling Knowledge (ALDK) algorithm which effectively transfers the teacher deformations $\bm{\phi}_t$ to the student deformations. An overview of our proposed deformable registration method can be found in Fig.2.


Figure 2: An overview of our proposed Light-weight Deformable Registration (LDR) method using Adversarial Learning with Distilling Knowledge (ALDK). Firstly, by using knowledge distillation, we extract the deformations from the Teacher Network as meaningful ground-truths. Secondly, we design a light-weight student network, which has competitive speed. Finally, we employ the Adversarial Learning with Distilling Knowledge algorithm to effectively transfer the meaningful knowledge of distilled deformations from the Teacher Network to the Student Network (Source).

Since the full content is long, in this part we introduce the background theory for Deformable Registration and Knowledge Distillation for Deformation. In the next part, we will introduce the architecture of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge. In the final part, we will show the effectiveness of the method in comparison with recent state-of-the-art approaches, together with a detailed analysis.

Background: Deformable Registration

We follow RCN [1] to define the deformable registration task recursively using multiple cascades. Let $\textbf{\textit{I}}_m, \textbf{\textit{I}}_f$ denote the moving image and the fixed image respectively, both defined over $d$-dimensional space $\bm{\Omega}$. A deformation is a mapping $\bm{\phi} : \bm{\Omega} \rightarrow \bm{\Omega}$. A reasonable deformation should be continuously varying and prevented from folding. The deformable registration task is to construct a flow prediction function $\textbf{F}$ which takes $\textbf{\textit{I}}_m, \textbf{\textit{I}}_f$ as inputs and predicts a dense deformation $\bm{\phi}$ that aligns $\textbf{\textit{I}}_m$ to $\textbf{\textit{I}}_f$ using a warp operator $\circ$ as follows:

$$\textbf{F}^{(n)}(\textbf{\textit{I}}^{(n-1)}_m,\textbf{\textit{I}}_f)=\phi^{(n)} \circ \textbf{F}^{(n-1)}(\phi^{(n-1)} \circ \textbf{\textit{I}}^{(n-2)}_m,\textbf{\textit{I}}_f)$$

where $\textbf{F}^{(n-1)}$ has the same form as $\textbf{F}^{(n)}$ but is a different flow prediction function. Assuming $n$ cascades in total, the final output is a composition of all predicted deformations, i.e.,

$$\textbf{F}(\textbf{\textit{I}}_m, \textbf{\textit{I}}_f)=\phi^{(n)} \circ \cdots \circ \phi^{(1)},$$

and the final warped image is constructed by

$$\textbf{\textit{I}}_{m}^{(n)}=\textbf{F}(\textbf{\textit{I}}_m,\textbf{\textit{I}}_f) \circ \textbf{\textit{I}}_m$$

In general, the previous equations form the hypothesis function $\mathcal{F}$ under the learnable parameter $\mathbf{W}$,

$$\mathcal{F}(\textbf{\textit{I}}_{m}, \textbf{\textit{I}}_f, \mathbf{W}) = (\mathbf{v}_{\phi}, \textbf{\textit{I}}_m^{(n)})$$

where $\mathbf{v}_{\phi} = [\bm{\phi}^{(1)}, \bm{\phi}^{(2)}, ..., \bm{\phi}^{(k)},..., \bm{\phi}^{(n)}]$ is a vector containing the predicted deformations of all cascades. Each deformation $\bm{\phi}^{(k)}$ can be computed as

$$\bm{\phi}^{(k)} = {\mathcal{F}}^{(k)}\left(\textbf{\textit{I}}_{m}^{(k-1)}, \textbf{\textit{I}}_f, \mathbf{W}_{\phi^{(k)}}\right)$$

To estimate and achieve a good deformation, different networks are introduced to define and optimize the learnable parameter $\mathbf{W}$.
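The cascade recursion above can be sketched in a few lines of Python. This is a toy 1-D illustration under stated assumptions, not the RCN implementation: `warp`, `run_cascades`, and `toy_predictor` are hypothetical names, each flow is stored as a per-pixel displacement, and the predictor returns a constant shift where a real cascade would be a trained network.

```python
from typing import Callable, List, Tuple

Image = List[float]
Flow = List[float]  # per-pixel displacement on a 1-D grid (toy setting)


def warp(flow: Flow, image: Image) -> Image:
    """Apply phi o I: sample the image at x + flow[x] with linear interpolation."""
    n = len(image)
    out = []
    for x in range(n):
        pos = min(max(x + flow[x], 0.0), n - 1.0)  # clamp to the grid
        lo = int(pos)
        hi = min(lo + 1, n - 1)
        w = pos - lo
        out.append((1.0 - w) * image[lo] + w * image[hi])
    return out


def run_cascades(
    moving: Image,
    fixed: Image,
    predictors: List[Callable[[Image, Image], Flow]],
) -> Tuple[List[Flow], Image]:
    """Each cascade k predicts phi^(k) from the current warped image and the fixed image."""
    flows: List[Flow] = []
    warped = moving
    for predict in predictors:
        phi = predict(warped, fixed)   # phi^(k) from I_m^(k-1) and I_f
        flows.append(phi)
        warped = warp(phi, warped)     # I_m^(k) = phi^(k) o I_m^(k-1)
    return flows, warped


def toy_predictor(moving: Image, fixed: Image) -> Flow:
    """Hypothetical stand-in for a learned cascade: always shift by one pixel."""
    return [1.0] * len(moving)


moving = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
fixed = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
flows, warped = run_cascades(moving, fixed, [toy_predictor, toy_predictor])
print(warped)
```

With two cascades that each shift by one pixel, the peak of the moving image moves two positions and lands on the peak of the fixed image, which mirrors how small per-cascade deformations compose into a larger one.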

Knowledge Distillation for Deformation

Knowledge distillation is the process of transferring knowledge from a cumbersome model (teacher model) to a distilled model (student model). The popular way to achieve this goal is to train the student model on a transfer set using a soft target distribution produced by the teacher model.
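As a rough illustration of the soft-target idea from [2], the sketch below computes a temperature-softened cross-entropy between teacher and student logits. The function names, logits, and temperature value are hypothetical and chosen only for the example.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax with temperature; a higher temperature gives a softer distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_loss(
    student_logits: List[float],
    teacher_logits: List[float],
    temperature: float = 2.0,
) -> float:
    """Cross-entropy between the teacher's soft targets and the student's prediction."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))


teacher = [4.0, 1.0, 0.5]
close_student = [3.5, 1.2, 0.4]   # ranks the classes like the teacher
far_student = [0.5, 1.0, 4.0]     # ranks the classes in reverse

loss_close = soft_target_loss(close_student, teacher)
loss_far = soft_target_loss(far_student, teacher)
print(loss_close < loss_far)
```

A student whose softened distribution matches the teacher's gets the lower loss, which is the signal the student is trained on.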

Different from typical knowledge distillation methods, which target the output softmax of neural networks as the knowledge, in the deformable registration task we leverage the teacher deformation $\bm{\phi}_t$ as the transferred knowledge. As discussed in [2], teacher networks are usually high-performing networks with good accuracy. Therefore, our goal is to leverage the current state-of-the-art Recursive Cascaded Networks (RCN) [1] as the teacher network for extracting meaningful deformations to transfer to the student network. The RCN network contains an affine transformation and a large number of dense deformable registration sub-networks designed by VTN [3]. Although the teacher network is computationally expensive, it is only used during training and not during inference.


[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., Recursive cascaded networks for unsupervised medical image registration, in ICCV, 2019.

[2] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, ArXiv, 2015.

[3] S. Zhao, T. Lau, J. Luo, E. I. Chang, and Y. Xu, Unsupervised 3d end-to-end medical image registration with volume tweening network, IEEE J-BHI, 2019.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Medical Visual Question Answering Challenges

Visual Question Answering in Medical Domain

Visual Question Answering (VQA) aims to provide a correct answer for a given question consistent with the visual content of a given image. The overarching goal of this issue is to create systems that can comprehend the contents of an image in the same way that humans do and communicate effectively about that image in natural language. It is indeed a challenging task as it necessitates the interaction and complementation of both image feature extractor and natural language processor.

In the medical domain, VQA could benefit both doctors and patients. VQA systems capable of understanding medical images and answering questions related to their content could support clinical education, clinical decision-making, and patient education. For example, doctors could use answers provided by a VQA system as supporting material in decision making. It can also help doctors get a second opinion on a diagnosis and reduce the high cost of training medical professionals. Patients, in turn, could ask a VQA system questions about their medical images to better understand their health.

Compared with VQA in the general domain, Med-VQA is a much more challenging problem. On one hand, clinical questions are more difficult but need to be answered with higher accuracy since they relate to health and safety.


Figure 1: An example of Medical VQA (Source).

There are challenges that need to be addressed when using VQA in the medical domain:

  • Lack of large scale labeled datasets.
  • Using Transfer learning in medical domain.

In this blog, we will analyze these challenges.

Challenges in Medical VQA

1. Lack of large scale labeled datasets.

One major problem with medical VQA is the lack of large-scale labeled training data, which usually requires huge effort to build. Difficulties in building a medical VQA dataset include:

  • (i) designing goal-oriented VQA systems and datasets
  • (ii) categorizing the clinical questions
  • (iii) selecting (clinically) relevant images
  • (iv) capturing the context and the medical knowledge.

The first attempt at building a dataset for medical VQA was made by ImageCLEF-Med. In it, images were automatically captured from PubMed Central articles, and the questions and answers were automatically generated from the corresponding image captions. Because of this construction, the data is highly noisy: the dataset includes many images that are not useful for direct patient care, and it also contains questions that do not make any sense.

Well-annotated datasets for training medical VQA systems are extremely scarce, because obtaining high-quality annotations from medical experts is laborious and expensive. VQA-RAD was the first manually constructed dataset released for the medical VQA task. Unfortunately, it contains only 315 images, which prevents powerful deep learning models from being applied directly to the VQA problem.

2. Using Transfer learning in a specific domain.

Transfer learning is an important step to extract meaningful features and overcome the data limitation in the medical Visual Question Answering (VQA) task. Transfer learning, in which deep learning models pretrained on a large-scale labeled dataset such as ImageNet are reused, is a popular way to initialize the feature extraction process. However, due to the difference in visual concepts between ImageNet images and medical images, fine-tuning alone is not sufficient. Recently, Model Agnostic Meta-Learning (MAML) has been introduced to overcome this problem by learning meta-weights that quickly adapt to visual concepts. However, MAML is heavily impacted by the meta-annotation phase for all images in the medical dataset. Compared with natural images, transfer learning on medical images is more challenging due to:

  • (i) noisy labels may occur when labeling images in an unsupervised manner;
  • (ii) high-level semantic labels cause uncertainty during learning;
  • (iii) difficulty in scaling up the process to all unlabeled images in medical datasets.

Multiple Meta-model Quantifying for Medical Visual Question Answering


A medical Visual Question Answering (VQA) system can provide meaningful references for both doctors and patients during the treatment process. Extracting image features is one of the most important steps in a medical VQA framework, as it outputs essential information for predicting answers. Transfer learning, in which deep learning models pretrained on a large-scale labeled dataset such as ImageNet are reused, is a popular way to initialize the feature extraction process. However, due to the difference in visual concepts between ImageNet images and medical images, fine-tuning alone is not sufficient. Recently, Model Agnostic Meta-Learning (MAML) has been introduced to overcome this problem by learning meta-weights that quickly adapt to visual concepts. However, MAML is heavily impacted by the meta-annotation phase for all images in the medical dataset. Compared with natural images, transfer learning on medical images is more challenging due to:

  • (i) noisy labels may occur when labeling images in an unsupervised manner;
  • (ii) high-level semantic labels cause uncertainty during learning;
  • (iii) difficulty in scaling up the process to all unlabeled images in medical datasets.

Federated Learning, Challenges and Hot Trends for research

Federated learning and its remaining challenge

Federated learning is the process of training statistical models via a network of distant devices or siloed data centers, such as mobile phones or hospitals, while keeping data locally. In terms of federated learning, there are five major obstacles that have a direct impact on the paper publishing trend.

1. Expensive Communication

Because of limited internet connectivity, the huge number of users, and administrative costs, communication between the devices and the server becomes a bottleneck.

2. Systems Heterogeneity

Because of differences in hardware (CPU, RAM), network connection (3G, 4G, 5G, wifi), and power (battery level), each device in a federated network may have different storage, computational, and communication capabilities.

3. Statistical Heterogeneity

Devices routinely generate and collect data in a non-identically distributed manner across the network; for example, in next-word prediction, mobile phone users employ a variety of languages. Furthermore, the quantity of data points on different devices may differ greatly, and an underlying structure may exist that describes the interaction between devices and their related distributions. This data generation paradigm violates the independent and identically distributed (I.I.D.) assumptions frequently used in distributed optimization, increases the likelihood of stragglers, and may add complexity in terms of modeling, analysis, and evaluation.

4. Privacy Concerns

In federated learning applications, privacy is a crucial problem. Federated learning takes a step toward data protection by sharing model changes, such as gradient information, rather than the raw data created on each device. Nonetheless, transmitting model updates during the training process may divulge sensitive information to a third party or the central server.
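To make the idea of sharing model updates instead of raw data concrete, here is a minimal sketch of one server-side aggregation round. It assumes a weighted average of client parameters (the scheme commonly known as federated averaging), which this post does not prescribe; `federated_average` and the toy numbers are hypothetical.

```python
from typing import List


def federated_average(
    client_weights: List[List[float]],
    client_sizes: List[int],
) -> List[float]:
    """Average client parameter vectors, weighted by local dataset size.

    Only parameters and sample counts reach the server; raw data stays on the devices.
    """
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[d] * n for w, n in zip(client_weights, client_sizes)) / total
        for d in range(dim)
    ]


# Two hypothetical clients holding 1 and 3 local samples respectively.
client_weights = [[1.0, 2.0], [3.0, 4.0]]
client_sizes = [1, 3]
global_weights = federated_average(client_weights, client_sizes)
print(global_weights)
```

Even in this scheme the server still sees every client's parameters, which is why the privacy concern described above remains.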

5. Domain transfer

Not every task can be trained under the federated learning paradigm, because of the four challenges above.

Hot trends

Data distribution heterogeneity and label inadequacy.

  • Distributed Optimization
  • Non-IID and Model Personalization
  • Semi-Supervised Learning
  • Vertical Federated Learning
  • Decentralized FL
  • Hierarchical FL
  • Neural Architecture Search
  • Transfer Learning
  • Continual Learning
  • Domain Adaptation
  • Reinforcement Learning
  • Bayesian Learning

Security, privacy, fairness, and incentive mechanisms:

  • Adversarial-Attack-and-Defense
  • Privacy
  • Fairness
  • Interpretability
  • Incentive Mechanism

Communication and computational resource constraints, software and hardware heterogeneity, the FL system

  • Communication-Efficiency
  • Straggler Problem
  • Computation Efficiency
  • Wireless Communication and Cloud Computing
  • FL System Design

Models and Applications

  • Models
  • Natural language Processing
  • Computer Vision
  • Health Care
  • Transportation
  • Recommendation System
  • Speech
  • Finance
  • Smart City
  • Robotics
  • Networking
  • Blockchain
  • Other

Benchmark, Dataset, and Survey

  • Benchmark and Dataset
  • Survey

Deep Metric Learning Meets Deep Clustering - An Novel Unsupervised Approach for Feature Embedding


The objective of deep distance metric learning (DML) is to train a deep learning model that maps training samples into feature embeddings that are close together for samples that belong to the same category and far apart for samples from different categories. Traditional DML approaches require supervised information, i.e., class labels, to supervise the training. Although supervised DML achieves impressive results on different tasks, it requires a large amount of annotated training samples to train the model. Unfortunately, such large datasets are not always available and they are costly to annotate for specific domains. That disadvantage also limits the transferability of supervised DML to new domains and applications which do not have labeled data. These reasons have motivated recent studies aiming at learning feature embeddings without annotated datasets, i.e., unsupervised deep distance metric learning (UDML). Our study is in that same direction, i.e., learning embeddings from unlabeled data.

There are two main challenges for UDML:

  • Firstly, how to define positive and negative samples for a given anchor data point, such that we can apply distance-based losses, e.g., pairwise loss or triplet loss, in the embedding space.
  • Secondly, how to make the training efficient, given a large number of pairs or triplets of samples, in the order of $\mathcal{O}(N^2)$ or $\mathcal{O}(N^3)$, respectively, in which $N$ is the number of training samples.
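To give a feel for the second challenge, the short sketch below counts unordered pairs and triplets for a hypothetical training set of 1,000 samples (the number is chosen only for illustration and is not a dataset size from this post).

```python
from math import comb


def count_pairs(n: int) -> int:
    """Number of unordered sample pairs, which grows as O(N^2)."""
    return comb(n, 2)


def count_triplets(n: int) -> int:
    """Number of unordered sample triplets, which grows as O(N^3)."""
    return comb(n, 3)


n_samples = 1000  # hypothetical training-set size
pairs = count_pairs(n_samples)
triplets = count_triplets(n_samples)
print(pairs)     # 499500
print(triplets)  # 166167000
```

Even at this modest size there are hundreds of thousands of pairs and over a hundred million triplets, so enumerating them all during training is impractical.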

In this paper, we propose a new method that utilizes deep clustering for deep metric learning to address the two challenges mentioned above. In particular,

  • We propose to use a deep clustering loss to learn centroids, i.e., pseudo labels, that represent semantic classes.
  • During learning, these centroids are also used to reconstruct the input samples. It hence ensures the representativeness of centroids — each centroid represents visually similar samples. Therefore, the centroids give information about positive (visually similar) and negative (visually dissimilar) samples.
  • Based on pseudo labels, we propose a novel unsupervised metric loss which enforces the positive concentration and negative separation of samples in the embedding space.



Figure 1: Illustration of the proposed framework which consists of an encoder (G), an embedding module (F), a decoder (D) and three losses, i.e., clustering loss $L_{rim}$, reconstruction loss $L_{rec}$ and metric loss $L_m$. The details are presented in the text (Source).

The proposed framework is presented in Figure 1.

  • For every original image in a batch, we make an augmented version by using a random geometric transformation.
  • The input images are fed into the backbone network which is also considered as the encoder (G) to get image representations.
  • The image representations are passed through the embedding module which consists of fully connected and L2 normalization layers (F), which results in unit norm image embeddings.
  • The clustering module takes image embeddings as inputs, performs the clustering with a clustering loss, and outputs the cluster assignments.
  • Given the cluster assignments, centroid representations are computed from image representations, which are then passed through the decoder (D) with a reconstruction loss to reconstruct images that belong to the corresponding clusters.
  • The centroid representations are also passed through the embedding module (F) to get centroid embeddings. The centroid embeddings and image embeddings are used as inputs for the metric loss.
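The step that turns cluster assignments into centroid representations can be sketched as follows. The post does not spell out how the centroids are computed, so this sketch assumes each centroid is the mean of the image representations assigned to its cluster; `centroid_representations` and the toy vectors are hypothetical.

```python
from typing import Dict, List


def centroid_representations(
    representations: List[List[float]],
    assignments: List[int],
    num_clusters: int,
) -> Dict[int, List[float]]:
    """Mean representation per non-empty cluster (assumed definition of a centroid)."""
    dim = len(representations[0])
    sums = [[0.0] * dim for _ in range(num_clusters)]
    counts = [0] * num_clusters
    for rep, cluster in zip(representations, assignments):
        counts[cluster] += 1
        for d, value in enumerate(rep):
            sums[cluster][d] += value
    return {
        c: [s / counts[c] for s in sums[c]]
        for c in range(num_clusters)
        if counts[c] > 0  # empty clusters have no centroid in this batch
    }


# Three toy image representations, two of them assigned to cluster 0.
reps = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]]
assignments = [0, 0, 1]
centroids = centroid_representations(reps, assignments, num_clusters=3)
print(centroids)
```

In the framework described above, these centroids would then be passed both to the decoder (D) and to the embedding module (F).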

Discriminative Clustering

We formulate the clustering of embedding features as a classification problem. Given a set of embedding features $X=\{x_i\}_{i=1}^m \in \mathbb{R}^{128 \times m}$ in a batch and the number of clusters $K\le m$ (i.e., the number of clusters $K$ is limited by the batch size $m$), let $Y=\{y_i\}_{i=1}^m$ be the set of softmax outputs for $X$. The cluster assignment for a feature $x$ with softmax output $y$ is then estimated by $c^* = \arg\max_c y$.

Inspired by Regularized Information Maximization (RIM), we use the following objective function (1) for the clustering.

$$L_{rim} = \mathcal{R}(\theta) - \lambda \left[ H(Y) - H(Y|X) \right] \quad (1)$$

where $H(.)$ and $H(.|.)$ are entropy and conditional entropy, respectively; $\mathcal{R}(\theta)$ regularizes the classifier parameters (in this work we use $l_2$ regularization); $\lambda$ is a weighting factor to control the importance of two terms.

Minimizing (1) is equivalent to maximizing $H(Y)$ and minimizing $H(Y|X)$. Increasing the marginal entropy $H(Y)$ encourages cluster balancing, while decreasing the conditional entropy $H(Y|X)$ encourages cluster separation.
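A small numeric sketch can show how the two entropy terms behave. It assumes the usual batch estimates (the marginal entropy is taken over the batch-mean softmax output, the conditional entropy is the mean per-sample entropy), which the post does not state explicitly; the regularizer is passed in as a plain number and all names are hypothetical.

```python
import math
from typing import List


def entropy(p: List[float]) -> float:
    """Shannon entropy of a probability vector, skipping zero entries."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def marginal_entropy(outputs: List[List[float]]) -> float:
    """H(Y), estimated as the entropy of the batch-mean softmax output."""
    m, k = len(outputs), len(outputs[0])
    mean = [sum(y[c] for y in outputs) / m for c in range(k)]
    return entropy(mean)


def conditional_entropy(outputs: List[List[float]]) -> float:
    """H(Y|X), estimated as the average entropy of each softmax output."""
    return sum(entropy(y) for y in outputs) / len(outputs)


def rim_loss(outputs: List[List[float]], lam: float = 1.0, reg: float = 0.0) -> float:
    """Objective (1): regularizer minus lambda times (H(Y) - H(Y|X))."""
    return reg - lam * (marginal_entropy(outputs) - conditional_entropy(outputs))


balanced = [[1.0, 0.0], [0.0, 1.0]]    # confident and evenly spread over clusters
collapsed = [[1.0, 0.0], [1.0, 0.0]]   # confident, but everything in one cluster
uncertain = [[0.5, 0.5], [0.5, 0.5]]   # spread out, but no sample is confident

for name, batch in [("balanced", balanced), ("collapsed", collapsed), ("uncertain", uncertain)]:
    print(name, rim_loss(batch))
```

Only the balanced and confident batch reaches the minimum of `-log(2)`; the collapsed and the uncertain batches both score 0, which matches the balancing and separation roles described above.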


In order to enhance the representativeness of centroids, we introduce a reconstruction loss (2) that penalizes high reconstruction errors from centroids to corresponding samples. Specifically, the decoder takes a centroid representation of a cluster and minimizes the difference between input images that belong to the cluster and the reconstructed image from the centroid representation.

$$L_{rec} = \frac{1}{m} \sum_{j=1}^K\sum\limits_{I_i\in X_j}||I_i - D(r_j)||^2, \quad (2)$$

where $D(.)$ is the decoder which reconstructs samples in the batch using their corresponding centroid representations and $m$ is the number of images in the batch.
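The reconstruction loss (2) can be written out directly. In this sketch images and centroid representations are plain Python lists, and `toy_decoder` is a hypothetical stand-in for the learned decoder D (here it simply returns its input).

```python
from typing import Callable, Dict, List


def reconstruction_loss(
    images: List[List[float]],
    assignments: List[int],
    centroids: Dict[int, List[float]],
    decoder: Callable[[List[float]], List[float]],
) -> float:
    """Mean squared L2 error between each image and the decoding of its cluster centroid."""
    m = len(images)
    total = 0.0
    for image, cluster in zip(images, assignments):
        reconstructed = decoder(centroids[cluster])
        total += sum((a - b) ** 2 for a, b in zip(image, reconstructed))
    return total / m


def toy_decoder(centroid: List[float]) -> List[float]:
    """Hypothetical decoder: the identity map, used only to keep the example small."""
    return list(centroid)


images = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]]
assignments = [0, 0, 1]
centroids = {0: [2.0, 3.0], 1: [10.0, 10.0]}
loss = reconstruction_loss(images, assignments, centroids, toy_decoder)
print(loss)
```

The two images in cluster 0 each contribute a squared error of 2 and the image in cluster 1 contributes 0, so the loss is 4/3.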

Metric loss

Let $f_i \in \mathbb{R}^{128}$ and $\hat{f_i} \in \mathbb{R}^{128}$ be the image embeddings of $I_i$ and $\hat{I_i}$, respectively. The proposed metric loss (3) aims to minimize the distance between $f_i$ and $\hat{f_i}$ while pushing $f_i$ far away from negative clusters.

$$L_{m} = -\sum_i \log \left( l(I_i,\hat{I_i}) \right) - \sum_i\sum\limits_{j=1, j \neq q}^K \log \left(l(I_i,c_j) \right) \quad (3)$$

Final Loss

The network in Figure 1 is trained in an end-to-end manner with the following multi-task loss.

$$L = \alpha L_{m} + \beta L_{rim} + \gamma L_{rec}, \quad (4)$$

where $L_m$ is the center-based softmax loss (3) for deep metric learning, $L_{rim}$ is the clustering loss (1), and $L_{rec}$ is the reconstruction loss (2).

Experimental Results

Ablation Study

We denote our model with:

  • only clustering loss (1) as only $L_{rim}$.
  • both the clustering loss (1) and the metric loss (3) as Center-based Softmax (CBS).
  • all three losses (1), (2) and (3) as Center-based Softmax with Reconstruction (CBSwR).

Table 1: The impact of each loss component on the performance on CUB200-2011 dataset and the comparison to the baseline (Source).


Table 2: The impact of each loss component on the performance on Car196 dataset and the comparison to the baseline (Source).

Tables 1 and 2 present the comparative results between methods. The results show that when using only the clustering loss, the accuracy is significantly lower than that of the baseline SME. However, using the centroids from the clustering to calculate the metric loss (i.e., CBS) gives a performance boost over the baseline (i.e., SME). Furthermore, the reconstruction loss enhances the representativeness of centroids, as confirmed by the improvements of CBSwR over CBS on both datasets.


Table 3: The training time (seconds) of different methods on CUB200-2011 and Car196 datasets with 20 epochs. The models are trained on a NVIDIA GeForce GTX 1080-Ti GPU (Source).


Table 4: The impact of the number of clusters of the final model CBSwR on the performance on CUB200-2011 dataset (Source).

Table 3 presents the training time of different methods on the CUB200-2011 and Car196 datasets. Although the asymptotic complexity of CBSwR for training one batch is $\mathcal{O}(Km)$, it also includes a decoder, which affects the actual training time. It is worth noting that the decoder is only involved during training. During testing, our method has similar computational complexity to SME.

Table 4 presents the impact of the number of clusters $K$ in the clustering loss on the CUB200-2011 dataset with our proposed model CBSwR (recall that the number of clusters $K$ is limited by the batch size $m$). During training, the number of samples per cluster varies depending on the batch and the number of clusters. At $K=32$, which is our final setting, the number of samples per cluster varies from 2 to 11 on average. The retrieval performance differs only slightly across different numbers of clusters. This confirms the robustness of the proposed method w.r.t. the number of clusters.

Comparison to the state of the art


Table 5: Clustering and Recall performance on the CUB200-2011 dataset (Source).


Table 6: Clustering and Recall performance on the Car196 dataset (Source).

Table 5 presents the comparative results on the CUB200-2011 dataset. In terms of clustering quality (NMI metric), the proposed method and the state-of-the-art UDML methods MOM and SME achieve comparable accuracy. However, in terms of retrieval accuracy R@K, our method outperforms other approaches. Our proposed method is also competitive with most of the supervised DML methods.

Table 6 presents comparative results on Car196 dataset. Compared to unsupervised methods, the proposed method outperforms other approaches in terms of retrieval accuracy at all ranks of K. Our method is comparable to other unsupervised methods in terms of clustering quality.

Qualitative Visualization


Figure 2: Barnes-Hut t-SNE visualization of our embedding on the CUB200-2011 dataset.

Figure 2 shows the t-SNE plots on our learned embedding features on CUB200-2011. We can see that our embedding produces reasonable results in grouping similar visual objects despite the significant variations in view-point, pose, and configuration.


We propose a new method that utilizes deep clustering for deep metric learning to address the two challenges in UDML, i.e., positive/negative mining and efficient training. The method is based on a novel loss that consists of a learnable clustering function, a reconstruction function, and a center-based metric loss function. Our experiments on CUB200-2011 and Car196 datasets show state-of-the-art performance on the retrieval task, compared to other unsupervised learning methods.

Open Source

🐱 Github: https://github.com/aioz-ai/BMVC20_CBSwR

Multiple interaction learning with question-type prior knowledge for constraining answer search space in visual question answering.

Different approaches have been proposed for Visual Question Answering (VQA). However, few works study how different joint modality methods behave with respect to question-type prior knowledge extracted from data, which constrains the answer search space and gives a reliable cue for reasoning about answers to questions asked about input images. In this blog, we share a novel VQA model that uses question-type prior information to improve VQA by leveraging multiple interactions between different joint modality methods, based on their behaviors in answering questions of different types. Experiments on two benchmark datasets, i.e., VQA 2.0 and TDIUC, indicate that the proposed method yields the best performance compared with the most competitive approaches.


There are works that consider question types as side information which gives a strong cue to reason about the answer. However, the relation between question types and answers in the training data has not been investigated yet. Fig. 1 shows the correlation between question types and some answers in the VQA 2.0 dataset. It clearly suggests that a question regarding quantity should be answered by a number, not a color. This observation indicates that the prior information obtained from the correlations between question types and answers provides an answer search space constraint for the VQA model. The search space constraint helps the VQA model produce its final prediction and thus improves the overall performance.

Figure 1. The distribution of candidate answers in each question type in VQA 2.0.

Although different joint modality methods or attention mechanisms have been proposed, we hypothesize that each method may capture different aspects of the input. That means different attentions may provide different answers for questions belonging to different question types.

Fig.2 shows examples in which the attention models (SAN and BAN) attend to different regions of input images when dealing with questions of different types. Unfortunately, most recent VQA systems are based on single attention models such as BAN2, SAN, MLP, MCB, STL. Given the above observation, it is necessary to develop a VQA system which leverages the power of different attention models to deal with questions of different question types.

Figure 2. Examples of attention maps of different attention mechanisms. BAN and SAN identify different visual areas when answering questions from different types.


The proposed multiple interaction learning with question-type prior knowledge (MILQT) is illustrated in Fig. 3. Similar to most VQA systems, MILQT consists of a joint learning solution for input questions and images, followed by a multi-class classification over a set of predefined candidate answers. However, MILQT makes it possible to leverage multiple joint modality methods under the guidance of question types to output better answers.

Figure 3. The introduced MILQT for VQA.

As shown in Fig.3, MILQT consists of two modules: Question-type awareness $\mathcal{A}$, and Multi-hypothesis interaction learning $\mathcal{M}$. The first module aims to learn the question-type representation, which is further used to enhance the joint visual-question embedding features and to constrain the answer search space through prior knowledge extracted from data. Based on the question-type information, the second module aims to identify the behaviors of multiple joint learning methods and then adjust their contributions to the final predictions.

Question representation. Given an input question, following the recent state-of-the-art BAN, we trim the question to a maximum of 12 words. Questions that are shorter than 12 words are zero-padded. Each word is then represented by a 600-D vector that is a concatenation of the 300-D GloVe word embedding and the augmenting embedding from training data. This step results in a sequence of word embeddings of size $12 \times 600$, denoted as $f_w$. In order to obtain the intent of the question, $f_w$ is passed through a Gated Recurrent Unit (GRU), which results in a 1024-D vector representation $f_q$ for the input question.
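The trimming and padding step can be sketched as below. The whitespace tokenizer, the toy vocabulary, and the choice of id 0 for padding are assumptions made for illustration; the embedding lookup and the GRU are left out.

```python
from typing import Dict, List

MAX_QUESTION_LEN = 12  # maximum number of words kept per question
PAD_ID = 0             # assumed id used for zero-padding
UNK_ID = 1             # assumed id for out-of-vocabulary words


def encode_question(
    question: str,
    vocab: Dict[str, int],
    max_len: int = MAX_QUESTION_LEN,
) -> List[int]:
    """Lowercase, split on whitespace, trim to max_len words, then pad with PAD_ID."""
    tokens = question.lower().replace("?", " ").split()
    ids = [vocab.get(token, UNK_ID) for token in tokens[:max_len]]
    return ids + [PAD_ID] * (max_len - len(ids))


# Hypothetical toy vocabulary.
vocab = {"what": 2, "color": 3, "is": 4, "the": 5, "cat": 6}

short_ids = encode_question("What color is the cat?", vocab)
long_ids = encode_question(" ".join(["what"] * 15), vocab)
print(short_ids)
print(len(long_ids))
```

Every question ends up as exactly 12 ids, so a batch of questions can be stacked into one fixed-size array before the embedding layer.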

Image representation. We use bottom-up attention, i.e., an object detector with a Faster R-CNN backbone, to extract the image representation. The input image is first passed through the bottom-up network to obtain a $K \times 2048$ bounding-box representation, denoted as $f_v$ in Fig. 3.

Multi-level multi-modal fusion. Previous works perform only one level of fusion between linguistic and visual features, which may limit the capacity of these models to learn a good joint semantic space. In our work, we introduce a multi-level multi-modal fusion that encourages the model to learn a better joint semantic space; it takes the question-type representation obtained from the question-type classification component as one of its inputs.

  • First level multi-modal fusion: The first level of fusion is similar to previous works. Given visual features $f_v$, question features $f_q$, and any joint modality mechanism,
    we combine the visual features with the question features and learn attention weights for the visual and/or linguistic features. Different attention mechanisms learn the joint semantic space in different ways. The output of the first level multi-modal fusion is denoted as $f_{att}$ in Fig. 3.
  • Second level multi-modal fusion: To enhance the joint semantic space, the output of the first level multi-modal fusion $f_{att}$ is combined with the question-type feature $f_{qt}$, which is the output of the last FC layer of the question-type classification component.
    We try two simple but effective operators, element-wise multiplication (EWM) and element-wise addition (EWA), to combine $f_{att}$ and $f_{qt}$. The output of the second level multi-modal fusion, denoted as $f_{att-qt}$ in Fig. 3, can be seen as an attention representation that is aware of the question-type information.
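The second level of fusion can be illustrated with a short sketch. The function name `fuse` and the toy 3-D vectors are assumptions for illustration; in the actual model, $f_{att}$ and $f_{qt}$ are learned feature vectors of equal dimension.

```python
def fuse(f_att, f_qt, op="EWM"):
    """Combine the attention feature with the question-type feature element-wise."""
    if len(f_att) != len(f_qt):
        raise ValueError("f_att and f_qt must have the same dimension")
    if op == "EWM":   # element-wise multiplication
        return [a * q for a, q in zip(f_att, f_qt)]
    if op == "EWA":   # element-wise addition
        return [a + q for a, q in zip(f_att, f_qt)]
    raise ValueError(f"unknown operator: {op}")


f_att = [0.5, 2.0, -1.0]   # toy output of the first-level fusion
f_qt = [2.0, 0.5, 1.0]     # toy question-type feature
print(fuse(f_att, f_qt, "EWM"))
print(fuse(f_att, f_qt, "EWA"))
```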

Given an attention mechanism, $f_{att-qt}$ is used as the input to a classifier that predicts an answer for the corresponding question. This is shown in the Answer prediction boxes in Fig. 3.

Multi-hypothesis interaction learning. As presented in Fig. 3, MILQT can utilize multiple hypotheses (i.e., joint modality mechanisms). Specifically, we propose a multi-hypothesis interaction learning design $\mathcal{M}$ that takes the answer predictions produced by different joint modality mechanisms and interactively learns to combine them.
Let $g \in \mathbb{R}^{A \times J}$ be the matrix of predicted probability distributions over $A$ answers from the $J$ joint modality mechanisms. $\mathcal{M}$ outputs the distribution $\rho \in \mathbb{R}^{A}$, which is calculated from $g$ through the equation below:

$$\rho = \mathcal{M}\left(g, w_{mil}\right) = \sum_{j}\left(m^T_{qt-ans} w_{mil} \odot g\right)$$

$w_{mil} \in \mathbb{R}^{P \times J}$ is the learnable weight that controls the contributions of the $J$ considered joint modality mechanisms to the answer prediction, guided by the $P$ question types; $\odot$ denotes the Hadamard product.
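To make the shapes in the equation concrete, here is a small pure-Python sketch. It assumes that $m_{qt-ans}$ is a $P \times A$ matrix relating question types to answers, so that $m^T_{qt-ans} w_{mil}$ has the same $A \times J$ shape as $g$; that shape, the helper names, and all numbers are assumptions for illustration only.

```python
def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def combine(g, w_mil, m_qt_ans):
    """rho[a] = sum_j (m_qt_ans^T w_mil)[a][j] * g[a][j]."""
    weights = matmul(transpose(m_qt_ans), w_mil)   # A x J
    return [sum(w * p for w, p in zip(w_row, g_row))
            for w_row, g_row in zip(weights, g)]


# Toy setup: A = 2 answers, J = 2 mechanisms, P = 1 question type.
g = [[0.8, 0.4],           # answer 0 scored by mechanisms 0 and 1
     [0.2, 0.6]]           # answer 1 scored by mechanisms 0 and 1
w_mil = [[0.25, 0.75]]     # P x J contribution weights
m_qt_ans = [[1.0, 1.0]]    # P x A (assumed shape)
rho = combine(g, w_mil, m_qt_ans)
print(rho)
```

With these toy weights, the second mechanism contributes three times as much as the first to every answer score.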


Experiments on VQA 2.0 test-dev and test-standard.

We evaluate MILQT on the test-dev and test-standard splits of the VQA 2.0 dataset. To train the model, similar to previous works, we use both the training set and the validation set of VQA 2.0. We also use Visual Genome as additional training data. MILQT consists of three joint modality mechanisms, i.e., BAN-2, BAN-2-Counter, and SAN, accompanied by EWM for the multi-modal fusion, and the predicted question type together with the prior information to augment the VQA loss. Table 1 presents the results of different methods on test-dev and test-std of VQA 2.0. The results show that MILQT performs well against the most competitive approaches.


Table 1. Comparison to the state of the art on the test-dev and test-standard of VQA 2.0. For fair comparison, GloVe embedding and GRU are leveraged for question embedding and bottom-up features are used to extract visual information. CMP, i.e., Cross-Modality with Pooling, is LXMERT with the aforementioned setup (Source).

Experiments on TDIUC.

To demonstrate the stability of MILQT, we also evaluate it on the TDIUC dataset.
The results in Table 2 show that the proposed model establishes state-of-the-art results on both evaluation metrics, Arithmetic MPT and Harmonic MPT. Specifically, our model significantly outperforms the recent QTA: overall, we improve over QTA by 6.1% and 11.1% on the Arithmetic MPT and Harmonic MPT metrics, respectively. It is worth noting that the QTA results in Table 2, which are cited from QTA, were achieved when QTA used the one-hot predicted question type of the test question to weight visual features. When using the ground-truth question type to weight visual features, QTA reported 69.11% and 60.08% for the Arithmetic MPT and Harmonic MPT metrics, respectively. Our model also outperforms these results by a large margin: the improvements are 3.9% and 6.8% for the Arithmetic MPT and Harmonic MPT metrics, respectively.
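For readers unfamiliar with the two metrics, the sketch below shows how an arithmetic and a harmonic mean over per-question-type accuracies differ. The per-type accuracies are made-up numbers, not results from Table 2, and reading MPT as "mean per type" is stated here as an assumption.

```python
from statistics import harmonic_mean, mean

# Hypothetical per-question-type accuracies (NOT results from the paper).
per_type_acc = {"counting": 0.5, "color": 0.8, "absurd": 0.9}

values = list(per_type_acc.values())
arithmetic_mpt = mean(values)           # every question type weighs equally
harmonic_mpt = harmonic_mean(values)    # dominated by the weakest question types

print(f"Arithmetic MPT: {arithmetic_mpt:.4f}")
print(f"Harmonic MPT:   {harmonic_mpt:.4f}")
```

Because the harmonic mean is pulled down by low-accuracy types, a model must do well on every question type to score highly on it.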


Table 2. The comparative results between the proposed model and other models on the validation set of TDIUC (Source).


We present MILQT, a multiple interaction learning method with question-type prior knowledge for constraining the answer search space, which takes the question-type information into account to improve VQA performance at different stages. The system can also utilize and learn different attention mechanisms under a unified model in an interacting manner. The extensive experimental results show that all proposed components improve VQA performance. MILQT performs best among the most competitive approaches on the VQA 2.0 and TDIUC datasets.

Open Source

Github: https://github.com/aioz-ai/ECCVW20_MILQT

A Brief Introduction to Visual Question Answering

1. Visual Question Answering - Overview

Visual Question Answering (VQA) aims to figure out a correct answer for a given question consistent with the visual content of a given image. The overarching goal of this task is to create systems that can comprehend the contents of an image in the same way that humans do and communicate effectively about that image in natural language. It is a challenging task because it requires the image feature extractor and the natural language processor to interact with and complement each other.

There are two main variants of VQA: Free-Form Open-Ended (FFOE) VQA and Multiple Choice (MC) VQA. In FFOE VQA, an answer is a free-form response to a given image-question pair input, while in MC VQA, an answer is chosen from an answer list for a given image-question pair input. The VQA variants will be discussed in the next post.

2. Approaches for solving VQA task.

There are three main approaches for VQA:

  • Compositional VQA models: the questions are interpreted as a set of sub-tasks.
  • Bayesian and Question-Aware models: this method is not well suited to systems that answer image-related questions, because algorithms based on it do not look at the picture; instead, they predict the response with a Bayesian model by estimating the probability of the words in the dataset's responses.
  • Attention based models: this method learns the interaction between image and question features in the VQA task through a module called attention. The joint features obtained from that module are then leveraged to answer the corresponding question.

The last approach is the most successful, since recent state-of-the-art methods rely on attention mechanisms.

3. Attention based VQA approach.

In general, attention based VQA approaches have four main steps (See Figure 1):

  • Visual Representation: Encode the information from the image into vector(s) by using Convolutional Neural Network (CNN)
  • Textual Representation: Encode the information of question into vector(s) by using Embedding.
  • Joint Representation: A further step to learn the interaction between question(s) and image(s). Output joint features can be vector(s).
  • Answer prediction: the joint features from the previous step are passed through this module to obtain the predicted answer. This module is usually a classifier.


Figure 1. The general approach for Visual Question Answering.

3.1. Visual Representation

Features are the basic attributes or aspects that help us recognize a specific object or image; distinguishing features are the properties that set it apart. When operating on a VQA dataset, we must extract the features of the images in order to separate them based on specific aspects. Image features are among the most important pieces of information a VQA system needs to output the correct answer.

Convolutional neural networks have emerged as the gold standard for image pattern recognition. An input image is converted into image features after it is passed through a convolutional network. Each filter in a CNN layer detects various patterns, such as corners, vertices, shapes, curves, and symmetries (See Figure 2).


Figure 2. An example of feature extraction in VQA classification.

The majority of VQA literature employs CNNs for image processing. The network's final layer is removed, and the remaining network is used to extract image features. For image representation in VQA, objects in images are commonly represented by features extracted from an object detector such as the Faster R-CNN bottom-up model.

3.2. Textual Representation

Textual embeddings can be obtained in a variety of ways. Count-based and frequency-based techniques such as count vectorization and TF-IDF are examples of older approaches. There are also prediction-based approaches such as continuous bag-of-words and skip-gram. Pretrained Word2Vec models are also openly accessible. Embeddings can also be created using deep learning architectures such as RNNs, LSTMs, GRUs, and 1-D CNNs; LSTMs are among the most commonly used in VQA literature. For question embedding in VQA, GloVe and BERT are widely used to capture the representation of words and sentences in different contexts (See Figure 3 for a sample structure of question embedding).


Figure 3. An example of question embedding for VQA.

3.3. Joint Representation

In current VQA systems, the joint modality component plays an essential role, since it learns meaningful joint representations between linguistic and visual inputs by applying the attention mechanism. Many works learn the interaction between question and image. One example is CTI (Do et al. 2018), a trilinear interaction model that simultaneously learns high-level associations between image, question, and answer information. See Figure 4 for more details.


Figure 4. Compact Trilinear Interaction mechanism for VQA (Source).

3.4. Answer Prediction

In most recent works, the joint features obtained from the attention mechanism are passed through a classifier to output the predicted answer. However, more modules can also be applied to bring in external knowledge and deal with difficult questions.

Overcoming Data Limitation in Medical Visual Question Answering

What are the difficulties when dealing with Medical VQA task?

Visual Question Answering (VQA) aims to provide a correct answer to a given question such that the answer is consistent with the visual content of a given image.

In the medical domain, VQA could benefit both doctors and patients. For example, doctors could use answers provided by a VQA system as supporting material in decision making, while patients could ask a VQA system questions about their medical images to better understand their health.


Figure 1: An example of Medical VQA (Source).

However, one major problem with medical VQA is the lack of large-scale labeled training data, which usually requires huge effort to build.

  • The first attempt at building a dataset for medical VQA was made by ImageCLEF-Med. In this dataset, images were automatically captured from PubMed Central articles, and the questions and answers were automatically generated from the corresponding image captions. Because of this construction, the data is very noisy: the dataset includes many images that are not useful for direct patient care, and it also contains questions that do not make sense.
  • Recently, VQA-RAD, the first manually constructed dataset for the medical VQA task, was released. Unfortunately, it contains only 315 images, which prevents powerful deep learning models from being applied directly to the VQA problem. One may consider transfer learning, in which deep learning models pretrained on a large-scale labeled dataset such as ImageNet are finetuned on medical VQA. However, due to the difference in visual concepts between ImageNet images and medical images, finetuning with very few medical images is not sufficient.

Therefore, it is necessary to develop a new VQA framework that improves accuracy while needing only a small amount of labeled training data.

The motivation for our approach to overcome the data limitation of medical VQA comes from two observations:

  • Firstly, we observe that large-scale unlabeled medical images are available. These images are from the same domain as medical VQA images. Hence, if we train an unsupervised deep learning model on these unlabeled images, the trained weights may adapt to the medical VQA problem more easily than weights pretrained on ImageNet images.
  • Another observation is that although the labeled dataset VQA-RAD is primarily designed for VQA, with a little effort we can extract new class labels for that dataset. The new class labels allow us to apply a recent meta-learning technique to learn meta-weights that can be quickly adapted to the VQA problem later.


The proposed medical VQA framework is presented in Figure 2. In our framework, the image feature extraction component is initialized by pretrained weights from MAML and CDAE. After that, the VQA framework will be finetuned in an end-to-end manner on the medical VQA data. In the following sections, we detail the architectures of MAML, CDAE, and our framework.


Figure 2: The proposed medical VQA. The image feature extraction is denoted as 'Mixture of Enhanced Visual Features (MEVF)' and is marked with the red dashed box. The weights of MEVF are initialized by MAML and CDAE (Source).

Model-Agnostic Meta-Learning -- MAML

The MAML model consists of four $3 \times 3$ convolutional layers with stride $2$ and ends with a mean pooling layer; each convolutional layer has $64$ filters and is followed by a ReLU layer.

We create the dataset for training MAML by manually reviewing around three thousand question-answer pairs from the training set of the VQA-RAD dataset. In our annotation process, images are split into three parts based on their body part labels (head, chest, abdomen). Images from each body part are further divided into three subcategories based on the interpretation of the question-answer pairs corresponding to the images. These subcategories are: 1. normal images, in which no pathology is found; 2. abnormal present images, in which fluid, air, a mass, or a tumor is present; 3. abnormal organ images, in which the organs are enlarged or in the wrong position.

Thus, all the images are categorized into 9 classes:

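| Body part | Normal | Abnormal present | Abnormal organ |
| --- | --- | --- | --- |
| Head | head normal | head abnormal present | head abnormal organ |
| Chest | chest normal | chest abnormal present | chest abnormal organ |
| Abdomen | abdominal normal | abdominal abnormal present | abdominal abnormal organ |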

In every iteration of MAML training (line 3 in Alg.1), 5 tasks are sampled. For each task, we randomly select 3 classes (from the 9 classes). For each class, we randomly select 6 images, of which 3 are used for updating the task models and the remaining 3 are used for updating the meta-model.
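The sampling scheme above can be sketched as follows. The function `sample_tasks`, the `support`/`query` naming, and the placeholder image ids are assumptions for illustration; only the counts (5 tasks, 3 classes, 3 + 3 images) come from the text.

```python
import random

CLASSES = [f"{part} {state}"
           for part in ("head", "chest", "abdominal")
           for state in ("normal", "abnormal present", "abnormal organ")]


def sample_tasks(images_by_class, n_tasks=5, n_way=3, n_support=3, n_query=3, rng=random):
    """Sample the tasks for one MAML iteration."""
    tasks = []
    for _ in range(n_tasks):
        task = {}
        for cls in rng.sample(sorted(images_by_class), n_way):
            picked = rng.sample(images_by_class[cls], n_support + n_query)
            task[cls] = {
                "support": picked[:n_support],   # used to update the task model
                "query": picked[n_support:],     # used to update the meta-model
            }
        tasks.append(task)
    return tasks


# Placeholder image ids; a real run would use VQA-RAD image paths.
images_by_class = {c: [f"{c}_{i}" for i in range(10)] for c in CLASSES}
tasks = sample_tasks(images_by_class, rng=random.Random(0))
print(len(tasks), [len(t) for t in tasks])
```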


Denoising Auto Encoder -- CDAE

The encoder maps an image $x'$, which is the noisy version of the original image $x$, to a latent representation $z$ that retains a useful amount of information. The decoder transforms $z$ to the output $y$. The training algorithm aims to minimize the reconstruction error between $y$ and the original image $x$ as follows:

$$L_{rec} = \left\| x - y \right\|_2^2$$

In our design, the encoder is a stack of convolutional layers, each followed by a max pooling layer. The decoder is a stack of deconvolutional and convolutional layers. The noisy version $x'$ is obtained by adding Gaussian noise to the original image $x$.
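The corruption step and the reconstruction loss can be illustrated on a flattened toy "image". The noise standard deviation `sigma` is an assumed value (the text does not state it), and the identity "decoder" is a stand-in for the real network.

```python
import random


def add_gaussian_noise(x, sigma=0.1, rng=random):
    """Return x' = x + n with n ~ N(0, sigma^2) applied per pixel."""
    return [v + rng.gauss(0.0, sigma) for v in x]


def reconstruction_loss(x, y):
    """L_rec = ||x - y||_2^2, the squared L2 distance."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


rng = random.Random(0)
x = [0.0, 0.5, 1.0, 0.25]                 # toy flattened image
x_noisy = add_gaussian_noise(x, rng=rng)  # input fed to the encoder
y = x_noisy                               # stand-in: a decoder that removes no noise
print(reconstruction_loss(x, y))          # the loss a trained CDAE should push down
```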

To train CDAE, we collect 11,779 unlabeled images available online, which are brain MRI images, chest X-ray images, and CT abdominal images. The dataset is split into a train set with 9,423 images and a test set with 2,356 images. We use Gaussian noise to corrupt the input images before feeding them to the encoder.

Our VQA framework

After training MAML and CDAE, we use their trained weights to initialize the MEVF image feature extraction component in the VQA framework. We then finetune the whole VQA model using the training set of VQA-RAD dataset.

To train the proposed model, we introduce a multi-task loss function to incorporate the effectiveness of the CDAE into VQA. Formally, our loss function is defined as follows:

$$L = \alpha_1 L_{vqa} + \alpha_2 L_{rec}$$

where $L_{vqa}$ is a cross-entropy loss for VQA classification and $L_{rec}$ stands for the reconstruction loss of CDAE. The whole VQA model is finetuned in an end-to-end manner.
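A small numeric sketch of the multi-task loss is given below. The logits, the reconstruction loss value, and the weights `alpha_1 = 1.0` and `alpha_2 = 0.5` are illustrative assumptions; the text does not state the values used.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of softmax(logits) against a single target class index."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def multitask_loss(l_vqa, l_rec, alpha_1=1.0, alpha_2=0.5):
    """L = alpha_1 * L_vqa + alpha_2 * L_rec (the weights here are assumed values)."""
    return alpha_1 * l_vqa + alpha_2 * l_rec


l_vqa = cross_entropy([2.0, 0.5, -1.0], target=0)   # toy answer logits
l_rec = 0.8                                         # toy CDAE reconstruction loss
print(multitask_loss(l_vqa, l_rec))
```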



Table 1: VQA results on VQA-RAD test set. All reference methods differ at the image feature extraction component. Other components are similar. The Stacked Attention Network (SAN) is used as the attention mechanism in all methods (Source).

Table 1 presents VQA accuracy on both open-ended and close-ended VQA-RAD questions on the test set. The results show that, for both MAML and CDAE, pretraining followed by finetuning significantly improves performance over training from scratch using only VQA-RAD.

In addition, the results show that our pretraining and finetuning of MAML and CDAE perform better than finetuning VGG-16 pretrained on the ImageNet dataset. Our proposed image feature extraction MEVF, which leverages the pretrained weights of both MAML and CDAE and then finetunes them, gives the best performance. This confirms the effectiveness of the proposed MEVF in dealing with the limited labeled training data for medical VQA.


Table 2: Performance comparison on VQA-RAD test set (Source).

Table 2 presents comparative results between methods. Note that for image feature extraction, the baselines use pretrained models (VGG or ResNet) that have been trained on ImageNet and then finetuned on the VQA-RAD dataset. For question feature extraction, all baselines and our framework use the same pretrained model (i.e., GloVe), finetuned on VQA-RAD. The results show that when BAN or SAN is used as the attention mechanism in our framework, it significantly outperforms the baseline frameworks BAN and SAN. Our best setting, i.e., the one with BAN as the attention mechanism, achieves state-of-the-art results and significantly outperforms the best baseline framework BAN: the improvements are 16.3% and 8.6% on open-ended and close-ended VQA, respectively.


In this paper, we proposed a novel medical VQA framework that leverages the meta-learning MAML and denoising auto-encoder CDAE for image feature extraction in order to overcome the limitation of labeled training data. Specifically, CDAE helps to leverage information from the large scale unlabeled images, while MAML helps to learn meta-weights that can be quickly adapted to the VQA problem. We establish new state-of-the-art results on VQA-RAD dataset for both close-ended and open-ended questions.

Open Source

🐱 Github: https://github.com/aioz-ai/MICCAI19-MedVQA

Data Augmentation for Colon Polyp Detection: A systematic Study

Colorectal cancer (CRC)♋, also known as bowel cancer or colon cancer, is a cancer that develops in the colon or rectum from a growth called a polyp. Detecting polyps is a common approach in screening colonoscopies to prevent CRC at an early stage. Early colon polyp detection from medical images is still an unsolved problem due to the considerable variation of polyps in shape, texture, size, color, and illumination, and the lack of publicly annotated datasets. At AIOZ, we adopt a recently proposed auto-augmentation method for polyp detection. We also conduct a systematic study of the performance of different data augmentation methods for colon polyp detection. The experimental results show that auto-augmentation achieves the best performance compared to other augmentation strategies.


Colorectal cancer (CRC) is the third-largest cause of cancer deaths worldwide in men and the second in women, with up to 700,000 deaths each year [1]. Detection and removal of colon polyps at an early stage will reduce the mortality from CRC. There are several methods for colon screening, such as CT colonography or wireless capsule endoscopy, but the gold standard is colonoscopy [2].

A colonoscopy is performed by an experienced doctor who uses a colonoscope to screen for abnormalities such as intestinal signs, symptoms, colon cancer, and polyps. Abnormal polyps can be removed, and small amounts of tissue can be detached for analysis during the colonoscopy. However, the most crucial drawback of colonoscopy is the polyp miss rate, especially for polyps smaller than 10 mm. Several factors cause the miss rate, both subjective factors such as bowel preparation, the specific choice of endoscope, video processor, and clinician skill, and objective factors such as polyp appearance and camera movement. For these reasons, automatic polyp detection is a promising approach to assist clinicians in improving the sensitivity of the diagnosis.

Previous research shows that automatic polyp detection using deep learning-based methods outperforms hand-crafted methods, as demonstrated by the top two results in the MICCAI 2015 challenge [3]. For deep learning-based approaches and model architectures, data augmentation is also a critical factor in making significant improvements, due to the lack of annotated data. The recent work [4] shows that learning an optimal augmentation policy from data, instead of hand-crafting data augmentation strategies, can generalize better. Thus, studying auto augmentation for polyp detection is necessary. In this research, we adapt Faster R-CNN [5] together with AutoAugment [6] to detect polyps in colonoscopy video frames. Besides, we also evaluate traditional data augmentation [7] to see the effectiveness of different augmentation strategies.


1. Polyp Detector

Thanks to the power of deep learning, recent works [5, 11, 12] show that deep learning-based detection methods give impressive detection performance. In this work, we use the Faster R-CNN object detector [5] with a ResNet-101 [13] backbone pre-trained on the COCO dataset. Our experiments show that this architecture gives competitive performance on the polyp detection problem. The experimental setting for the detector is as follows. The network is trained using stochastic gradient descent (SGD) with 0.9 momentum; the learning rate is initially set to 3e-4 and is decreased to 3e-5 from iteration 900k. The number of anchor boxes per location used in our model is 12 (4 scales, i.e., 64×64, 128×128, 256×256, 512×512, and 3 aspect ratios, i.e., 1:2, 1:1, 2:1).
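The anchor configuration and the learning-rate schedule can be sketched as below. Treating each ratio as width:height and keeping the anchor area fixed per scale are assumptions about the convention (detection frameworks differ on this); the function names are hypothetical.

```python
import math

SCALES = [64, 128, 256, 512]   # anchor side lengths at aspect ratio 1:1
RATIOS = [0.5, 1.0, 2.0]       # 1:2, 1:1, 2:1, read as width / height (assumption)


def anchor_shapes(scales=SCALES, ratios=RATIOS):
    """Return (width, height) for every scale/ratio pair, preserving the area per scale."""
    shapes = []
    for s in scales:
        area = float(s * s)
        for r in ratios:
            w = math.sqrt(area * r)
            shapes.append((w, area / w))
    return shapes


def learning_rate(iteration):
    """Step schedule from the text: 3e-4, then 3e-5 from iteration 900k onward."""
    return 3e-4 if iteration < 900_000 else 3e-5


print(len(anchor_shapes()))    # 4 scales x 3 ratios
print(learning_rate(0), learning_rate(900_000))
```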

2. Data Augmentation


Fig. 1. Example of applying learned augmentation policies to a colonoscopy image.

Data augmentation can be split into two types: self-defined data augmentation (a.k.a traditional augmentation) and auto augmentation [6]. In this study, we adopt an automated data augmentation approach for object detection, i.e., Auto-augment [6], which finds optimal data augmentation policies during training. In Auto-augment, an augmentation policy consists of several sub-policies; each sub-policy consists of two operations. Each operation is an image transformation containing two parameters: probability and the magnitude of the shift. There are three types of transformations used in Auto-augment for object detection [4], which are

  • Color operations: distort color channels without impacting the locations of the bounding boxes
  • Geometric operations: geometrically distort the image, which correspondingly alters the location and size of the bounding box annotations
  • Bounding box operations: only distort the pixel content contained within the bounding box annotations.

One of the essential conclusions in [4] is that the learned policy found on COCO can be directly applied to other detection datasets and models to improve predictive accuracy. Hence, in this study, we apply the learned policies from [4] to augment data when training the detector in Sec. 3.1. The learned policies we use for training our detector are summarized in Table 1. In Table 1, each operation is a triple that describes the transformation, its probability, and its magnitude. Due to space limitations, we refer the reader to [4] for detailed descriptions of the transformations. Fig. 1 shows augmented examples obtained by applying the learned augmentation policy to a polyp image from the training dataset.
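The policy structure described above (a policy made of sub-policies, each holding two operations with a probability and a magnitude) can be sketched as plain data. The operation names and numbers below are placeholders, not the learned values in Table 1, and picking one sub-policy at random per image follows our reading of the AutoAugment papers rather than anything stated in this post.

```python
import random

# Each operation is a (name, probability, magnitude) triple; values are placeholders.
POLICY = [
    [("color_op", 0.6, 4), ("bbox_op", 0.3, 8)],
    [("geometric_op", 0.8, 2), ("color_op", 0.2, 6)],
]


def choose_operations(policy, rng):
    """Pick one sub-policy, then keep each of its operations with its own probability."""
    sub_policy = rng.choice(policy)
    return [(name, magnitude)
            for name, probability, magnitude in sub_policy
            if rng.random() < probability]


rng = random.Random(0)
for _ in range(3):
    print(choose_operations(POLICY, rng))   # operations to apply to one training image
```

An actual pipeline would map each operation name to an image transform that also updates the bounding boxes where needed.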


Table 1. Sub-policies and operations used in our experiment.

In addition to auto-augmentation, we also investigate the effect of traditional augmentation and the combination of traditional and automatic augmentation. For traditional data augmentation, we randomly apply several transformations to the image, such as rotation, mirroring, shearing, translation, and zoom. We propose different strategies to combine those data augmentation types: (1) the detector is first trained with Auto-augment and then with the traditional data augmentation; (2) the detector is trained with the traditional augmentation and then with auto augmentation; (3) the detector is trained with AutoAugment on the original data and the data generated by traditional augmentation. All augmentation strategies are evaluated with the same model architecture and training configuration. This allows us to explore which data augmentation method is suitable for the polyp detection problem.


We use CVC-ClinicDB [14] for training and ETIS-Larib [15] for testing. This allows us to make a fair comparison with the MICCAI 2015 challenge results, which are reported on the same dataset. The CVC-ClinicDB database contains 612 polyp image frames of 31 unique polyps from 31 different colonoscopy videos. The ETIS-Larib dataset contains 196 high-resolution image frames of 44 different polyps.


Fig. 2. Examples of false-positive detection on the testing dataset. Green boxes and blue boxes are ground truths and predictions, respectively.


Fig. 3. Examples of false-negative detection on the testing dataset. Green boxes and blue boxes are ground truths and predictions, respectively.

Fig. 2 and Fig. 3 visualize several failure cases of our model on the testing dataset, in which the blue boxes are the predicted locations and the green boxes are ground truths. The false-positive samples (Fig. 2) are caused by shortcomings in bowel preparation (i.e., leftovers of food and fluid in the colon), while the false-negative samples (Fig. 3) are caused by variations in polyp type and appearance (i.e., small polyps, flat polyps, and similarities between polyps and colon veins).


Table 2. Comparison among traditional data augmentation (TDA), auto augmentation (AA) and their combinations.
Table 2 shows the comparative results between different augmentation strategies. The results show that the third combination method (AA-TDA-3) achieves higher Precision than AutoAugment, i.e., 75.90% versus 74.51%. However, overall, Auto-augment (AA) achieves the best results because of how well it limits the polyp miss rate (152 polyps detected) while keeping an acceptable number of false positives (52). The competitive performance of auto augmentation (AA) confirms that the data augmentation policies learned on the COCO dataset [4] are transferable.
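As a quick sanity check of the numbers quoted above, precision can be recomputed from the counts. Reading 152 as the number of true positives and 52 as the number of false positives is an assumption, but it is consistent with the reported AA precision of 74.51%.

```python
def precision(tp, fp):
    """Fraction of predicted boxes that are correct detections."""
    return tp / (tp + fp)


def recall(tp, fn):
    """Fraction of ground-truth polyps that are detected (1 - miss rate)."""
    return tp / (tp + fn)


tp, fp = 152, 52   # counts quoted in the text for Auto-augment (AA)
print(f"Precision: {100 * precision(tp, fp):.2f}%")
```

The recall helper is included for completeness; the number of missed polyps is not quoted in the text, so no recall value is computed here.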


Table 3. Comparative results between our model and the state of the art.
Table 3 presents the comparative results between the auto augmentation in our model and other state-of-the-art results. Among the compared methods, CUMED, OUR, and UNS-UCLAN are end-to-end deep learning-based approaches. The results show that, compared to methods from the MICCAI challenge [3], auto augmentation achieves better performance on all metrics. Compared to the recent method [10], auto augmentation also achieves better performance on all metrics except FP. These results confirm the effectiveness of auto augmentation for polyp detection.


This study adopts a deep learning-based object detection method with auto data augmentation for polyp detection. Different augmentation strategies are evaluated. The experimental results show that the auto augmentation policies learned from a general object detection dataset transfer well to the polyp detection problem. Although auto augmentation achieves competitive results, it still has a high FP count compared to the state of the art. This weakness can be addressed by post-processing steps such as false-positive learning.

Open Source

🍅 Github: https://github.com/aioz-ai/polyp-detection

🍓 Blog post: https://ai.aioz.io/blog/polyp-detection


This research was conducted by Phong Nguyen, Quang Tran, Erman Tjiputra, and Toan Do. We’d like to give special thanks to the other AIOZ AI team members for their support and feedback.

🎉 All the above contributions were incredibly enabling for this research. 🎉


[1] Hamidreza Sadeghi Gandomani, Mohammad Aghajani, et al., “Colorectal cancer in the world: incidence, mortality and risk factors,” Biomedical Research and Therapy, 2017.

[2] Florence Bénard, Alan N Barkun, et al., “Systematic review of colorectal cancer screening guidelines for average-risk adults: Summarizing the current global recommendations,” World Journal of Gastroenterology, 2018.

[3] Jorge Bernal, Nima Tajbakhsh, et al., “Comparative validation of polyp detection methods in video colonoscopy: results from the MICCAI 2015 endoscopic vision challenge,” IEEE Transactions on Medical Imaging, pp. 1231–1249, 2017.

[4] Barret Zoph, Ekin D Cubuk, et al., “Learning data augmentation strategies for object detection,” arXiv, 2019.

[5] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun, “Faster R-CNN: Towards real-time object detection with region proposal networks,” in NIPS, 2015.

[6] Ekin D Cubuk, Barret Zoph, et al., “Autoaugment: Learning augmentation strategies from data,” in CVPR, 2019.

[7] Younghak Shin, Hemin Ali Qadir, et al., “Automatic colon polyp detection using region based deep CNN and post learning approaches,” IEEE Access, 2018.

[8] Yangqing Jia, Evan Shelhamer, et al., “Caffe: Convolutional architecture for fast feature embedding,” in ACMMM, 2014.

Introduction to Federated Learning

Federated Learning: machine learning over a distributed dataset, where user devices (e.g., desktop, mobile phones, etc.) are utilized to collaboratively learn a shared prediction model while keeping all training data locally on the device. This approach decouples the ability to do machine learning from storing the data in the cloud.

Conceptually, federated learning proposes a mechanism to train a high-quality centralized model. Simultaneously, training data remains distributed over many clients, each with unreliable and relatively slow network connections.

The idea behind federated learning is as conceptually simple as it is technologically complex. Traditional machine learning programs rely on a centralized model for training, in which a group of servers runs a specific model against training and validation datasets. That centralized training approach can work very efficiently in many scenarios. Still, it has proven challenging in use cases involving a large number of endpoints using and improving the model. The prototypical example of the limitation of the centralized training model can be found in mobile or internet of things (IoT) scenarios, where the quality of a model depends on the information processed across hundreds of thousands or millions of devices. In those scenarios, each endpoint can contribute to a machine learning model's training in its own autonomous way. In other words, knowledge is federated.
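One common way to combine the contributions of many endpoints is federated averaging, where a server averages the locally trained weights, weighted by how much data each client holds. The text above does not name a specific aggregation rule, so the sketch below is an illustrative assumption rather than the method described; the function name and numbers are hypothetical.

```python
def federated_average(client_weights, client_sizes):
    """Average client weight vectors, weighting each client by its local dataset size."""
    total = sum(client_sizes)
    dim = len(client_weights[0])
    return [
        sum(w[i] * n for w, n in zip(client_weights, client_sizes)) / total
        for i in range(dim)
    ]


# Two toy clients: the raw data never leaves the devices, only the weights are shared.
client_weights = [[1.0, 2.0], [3.0, 4.0]]
client_sizes = [1, 3]   # number of local training examples per client
global_weights = federated_average(client_weights, client_sizes)
print(global_weights)
```

The server would then send the averaged weights back to the clients for the next round of local training.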

Blockchain: large, distributed dataset, where no one can edit or delete an old entry, nor fake a new entry. The integrity of the data is enforced by fundamental limits of computation (i.e., Proof of Work).
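A toy sketch of how Proof of Work makes old entries hard to edit is given below. It is not how any particular blockchain is implemented: the block layout, the `difficulty` parameter, and the function names are hypothetical, and the difficulty is kept very low so the example runs quickly.

```python
import hashlib

GENESIS = "0" * 64  # placeholder hash for the block before the first one


def block_hash(prev_hash: str, data: str, nonce: int) -> str:
    """Hash a block's contents together with the previous block's hash."""
    return hashlib.sha256(f"{prev_hash}|{data}|{nonce}".encode()).hexdigest()


def mine(prev_hash: str, data: str, difficulty: int = 2):
    """Search for a nonce whose hash starts with `difficulty` zero hex digits."""
    nonce = 0
    while True:
        h = block_hash(prev_hash, data, nonce)
        if h.startswith("0" * difficulty):
            return nonce, h
        nonce += 1


def build_chain(entries, difficulty: int = 2):
    """Mine one block per entry, linking each block to its predecessor."""
    chain, prev = [], GENESIS
    for data in entries:
        nonce, h = mine(prev, data, difficulty)
        chain.append({"prev": prev, "data": data, "nonce": nonce, "hash": h})
        prev = h
    return chain


def is_valid(chain, difficulty: int = 2) -> bool:
    """Recompute every hash and check the links and the work requirement."""
    prev = GENESIS
    for block in chain:
        h = block_hash(block["prev"], block["data"], block["nonce"])
        if block["prev"] != prev or h != block["hash"]:
            return False
        if not h.startswith("0" * difficulty):
            return False
        prev = h
    return True


chain = build_chain(["alice pays bob 5", "bob pays carol 2"])
```

Editing the data of an old block changes its hash, so the stored hash and every later link stop matching unless the work is redone for that block and all blocks after it.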

Smart Contract: dataset stored on the Blockchain, which includes: Data (i.e., ledgers, events, statistics), State (today's ledger, today's events), Code (rules for changing state).

Math Examples

Fundamental Theorem of Calculus. Let $f:[a,b] \to \R$ be Riemann integrable. Let $F:[a,b]\to\R$ be $F(x)= \int_{a}^{x}f(t)\,dt$. Then $F$ is continuous, and at all $x$ such that $f$ is continuous at $x$, $F$ is differentiable at $x$ with $F'(x)=f(x)$.
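As a small worked instance of the theorem, take the (arbitrarily chosen) function $f(t) = 2t$ on $[0,1]$, which is continuous everywhere on the interval:

```latex
$$
F(x) = \int_{0}^{x} 2t \, dt = \left[ t^2 \right]_{0}^{x} = x^2,
\qquad
F'(x) = \frac{d}{dx}\, x^2 = 2x = f(x).
$$
```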

Lift ($L$) can be determined from the lift coefficient ($C_L$) using the following equation.

$$L = \frac{1}{2} \rho v^2 S C_L$$
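The equation can be evaluated directly. The sketch below assumes the usual reading of the symbols, which the text does not spell out: $\rho$ is the air density, $v$ the airspeed, and $S$ the wing area. The numeric values are made up for illustration.

```python
def lift(rho: float, v: float, S: float, C_L: float) -> float:
    """Lift force L = 1/2 * rho * v^2 * S * C_L (SI units assumed)."""
    return 0.5 * rho * v ** 2 * S * C_L


# Hypothetical inputs: air density 1.225 kg/m^3, airspeed 50 m/s,
# wing area 16 m^2, lift coefficient 0.5.
L = lift(rho=1.225, v=50.0, S=16.0, C_L=0.5)
print(f"Lift: {L:.1f} N")
```

Because $v$ enters squared, doubling the airspeed quadruples the lift when the other quantities stay fixed.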


[1] J. Rodriguez. What's New in Deep Learning Research: Understanding Federated Learning.

Graph-based Person Signature for Person Re-Identifications


Person re-identification (ReID) aims to retrieve a particular person image in a collection of images captured by multiple cameras from various viewpoints across time.

The challenges of the person ReID task come from significant variations in human attributes such as poses, gaits, and clothes, as well as challenging environmental settings like illumination, complex backgrounds, and occlusions. With the rise of deep learning, most recent studies utilize Convolutional Neural Networks (CNNs) to tackle the person ReID problem.

Recently, attribute-based methods have shown great success in providing semantic features for the deep network. Unlike the person identity label, which offers only coarse information to identify one identity among all other person identities, the attributes are the detailed descriptions that are highly intuitive and mostly unchanged between images captured from different cameras. Therefore, they can be used to explicitly guide the model to learn a robust person representation by defining human characteristics.

In this work, we propose to utilize the person attribute information with its associated body part to encode the visual person signature in one unified framework.

  • We hypothesize that the detailed person descriptions (attributes labels) can be integrated with visual features (body parts and global features) to create a unique signature for a particular person.
  • Since both body parts and attributes provide local representations, by linking them together, the network can have a better understanding of the relationship between visual features and attribute descriptions.
  • Although previous works have investigated how person identity, body parts, and attributes benefit the task of person ReID, our key difference is that we utilize Graph Convolutional Networks (GCN) to effectively construct and model the correlation between attributes and body parts with global features. In particular, we treat body part regions and attributes as nodes in a graph and utilize a GCN to learn the topological structure of a person's signature. The GCN propagates messages on the graph structure. After message passing on the graph, each node's final representation is obtained from its own data and from the information of other nodes. Figure 1 shows the effectiveness of our approach.
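The message propagation described in the last point can be sketched with a single graph convolution step. This is a simplified illustration, not the GPS implementation: it uses row normalization with self-loops instead of any specific normalization from the paper, and the three-node graph, the feature values, and the weight matrix are hypothetical.

```python
def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def normalize_adjacency(A):
    """Add self-loops, then row-normalize so each node averages over
    itself and its neighbours."""
    n = len(A)
    A_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
             for i in range(n)]
    return [[v / sum(row) for v in row] for row in A_hat]


def gcn_layer(A_norm, H, W):
    """One propagation step: H' = ReLU(A_norm @ H @ W)."""
    Z = matmul(matmul(A_norm, H), W)
    return [[max(0.0, z) for z in row] for row in Z]


# Hypothetical graph: node 0 = "upper body" part, node 1 = "lower body" part,
# node 2 = "backpack" attribute. Only nodes 0 and 2 are linked.
A = [[0.0, 0.0, 1.0],
     [0.0, 0.0, 0.0],
     [1.0, 0.0, 0.0]]
H = [[2.0, 0.0],   # node features (2-D for readability)
     [0.0, 4.0],
     [0.0, 2.0]]
W = [[1.0, 0.0],   # identity weights keep the arithmetic easy to follow
     [0.0, 1.0]]

H_next = gcn_layer(normalize_adjacency(A), H, W)
print(H_next)
```

After one step, the linked nodes 0 and 2 share a blended representation, while the isolated node 1 keeps its own features. This mirrors the statement that a node's final representation combines its own data with information from other nodes.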


Figure 1: The effectiveness of our GPS in improving retrieval results on Market-1501 dataset. The details are presented in the text (Source).



Figure 2: Illustration of our proposed framework, which includes two branches: (1) the global branch, which extracts global person features; (2) the GPS branch, which performs reasoning over the person attributes and body parts using a GCN. The details are presented in the text (Source).

The proposed framework is presented in Figure 2.

  • Let $I$ denote a probe person image. This probe image $I$ is first passed through a backbone CNN to get the feature map $\mathbf{F}$.
  • By utilizing a human parsing pretrained model, we extract the body part masks to obtain the visual features of each part.
  • The person attributes are then represented by a lookup word embedding.
  • Given body part features and attribute features, we construct the Graph-based Person Signature, which includes attribute nodes and body part nodes conditioned on the correlation matrix. We employ the GCN to reason on the person signature graph and encode the graph into more representative features.

Our proposed method is a multi-branch multi-task framework for person ReID, where the main branch performs the verification task by optimizing two well-known loss functions: Triplet loss and Center loss. The auxiliary branch performs reasoning on the proposed person signature graph and solves the attribute recognition as well as the person identity classification tasks.
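The two verification losses named above can be sketched as follows. This is a generic formulation under stated assumptions, not the exact configuration used for GPS: the margin value, the use of Euclidean distance, the averaging over the batch in the center loss, and all function names are illustrative choices.

```python
import math


def euclidean(a, b):
    """Euclidean distance between two feature vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def triplet_loss(anchor, positive, negative, margin=0.3):
    """Push the anchor closer to a same-identity sample than to a
    different-identity sample by at least `margin`."""
    return max(0.0, euclidean(anchor, positive)
               - euclidean(anchor, negative) + margin)


def center_loss(features, labels, centers):
    """Half the mean squared distance between each feature and the
    centre of its identity class."""
    total = sum(euclidean(f, centers[y]) ** 2
                for f, y in zip(features, labels))
    return 0.5 * total / len(features)


# Hypothetical 2-D embeddings.
anchor = [0.0, 0.0]
same_id = [3.0, 4.0]    # distance 5 from the anchor
other_id = [6.0, 8.0]   # distance 10 from the anchor

easy = triplet_loss(anchor, same_id, other_id)   # already well separated
hard = triplet_loss(anchor, other_id, same_id)   # roles swapped: violated
compact = center_loss([[1.0, 0.0], [0.0, 1.0]], [0, 0], {0: [0.0, 0.0]})
```

The triplet loss acts on relative distances between identities, while the center loss pulls features of one identity toward a shared centre, so the two terms complement each other.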

Experimental Results

Ablation Study

Loss Contribution. In Table 2, we show the contribution of each loss to the final performance on the Market1501 dataset. The person ID classification loss, triplet loss, center loss, and attribute recognition loss are denoted as $\mathbf{\mathcal{L}_{id}}$, $\mathbf{\mathcal{L}_{triplet}}$, $\mathbf{\mathcal{L}_{center}}$, and $\mathbf{\mathcal{L}_{a}}$, respectively. The performance improves when we incorporate all losses into the framework, which justifies the effectiveness of our proposed method. By using only $\mathbf{\mathcal{L}_{id}}$, we still achieve competitive results compared with other mask-guided and attribute-based methods. While the triplet loss $\mathbf{\mathcal{L}_{triplet}}$ clearly improves the performance, the center loss $\mathbf{\mathcal{L}_{center}}$ has only a slight impact. Notably, the attribute loss $\mathbf{\mathcal{L}_{a}}$ remains stable when combined with the other loss functions.


Table 2: The contribution of losses to the performance of person ReID task on Market1501 dataset. Note that the experiments are conducted with ResNet-50 as backbone CNN network (Source).

Model Interpretability. In this section, we conduct cross-dataset experiments to evaluate the effectiveness of GPS. The model is trained on the source dataset and tested directly on the target dataset without finetuning. As shown in Table 3, our GPS achieves a significant improvement over the Bag-of-Tricks baseline (BoT). This demonstrates the interpretability of our proposed method and confirms the effectiveness of learning the attributes for the person ReID task.


Table 3: The transferable ability of our GPS evaluated on cross-dataset (Source).

Training Parameters. We also provide the number of training parameters of our GPS and the baseline BoT in Table 4 to show the complexity of each method. Overall, our GPS adds only about 3M parameters compared with the baseline BoT while achieving much better performance.


Table 4: The number of parameters of our GPS in comparison with the baseline BoT on Market1501 and DukeMTMC-ReID datasets using ResNet-50 as the backbone network. #nParam indicates the number of parameters and 1K=1000 (Source).

Comparison to the state of the art

GPS vs. Baseline. The last two rows of Table 5 show the results of our GPS when it is integrated into the baseline BoT. The results clearly show that our GPS significantly improves the performance of BoT on both the Market-1501 and DukeMTMC-ReID datasets. This demonstrates the effectiveness of our GPS and confirms the usefulness of learning the attributes in the ReID task.


Table 5: Comparison with state-of-the-art methods on Market-1501 and DukeMTMC-ReID datasets. The cyan and yellow boxes are the best results corresponding to mask-guided/attribute-based and other approaches, respectively. Note that no post-processing is applied to our method (Source).

Evaluation on Market-1501. We compare our GPS with other methods on the Market-1501 dataset in Table 5. The results show that our method outperforms AANet, the state-of-the-art attribute-based method that uses attribute and body part information, in all evaluation metrics. Specifically, we outperform AANet by 5.3% and 1.3% at mAP and R-1, respectively. Our GPS also outperforms the state-of-the-art mask-guided methods; in particular, we outperform P$^2$-Net by 2.2% at mAP. At the same time, we obtain competitive results compared with other recent ReID approaches.

Evaluation on DukeMTMC-ReID. Table 5 also summarizes the results of our GPS and other methods on the DukeMTMC-ReID dataset. Our GPS significantly outperforms other attribute-based methods in all metrics. Specifically, our method outperforms the recent state-of-the-art attribute-based method AANet by 6.1% at mAP and 1.8% at R-1. In addition, we outperform ADPR by 9.0%, 3.9%, 2.8%, and 2.0% at mAP, R-1, R-5, and R-10, respectively. Moreover, our GPS outperforms the state-of-the-art mask-guided method P$^2$-Net by 4.9%, 1.7%, 2.1%, and 1.7% at mAP, R-1, R-5, and R-10, respectively. We also achieve competitive results compared with other ReID approaches.

Attribute-based and Mask-guided vs. Other approaches. From Table 5, we notice that although our GPS shows a definite improvement over mask-guided and attribute-based methods, it only achieves competitive results against methods from other approaches and, in particular, is outperformed by the st-ReID method. Note that st-ReID also outperforms all methods from all other approaches. The effectiveness of st-ReID comes from the fact that it additionally feeds spatial-temporal information (i.e., the spatial map of the camera setting and temporal information from video timestamps) into the network. This extra information allows the network to encode the person identity from multiple viewpoints, which significantly reduces the effect of different poses, viewpoints, or ambiguity challenges. From our experiments, we have observed that our GPS, as well as other attribute-based and mask-guided methods, suffers when the pretrained body part network cannot provide adequate segmentation masks, which in turn degrades the retrieval results.

Qualitative Visualization


Figure 3: Top 5 retrieval results of some queries on Market-1501 dataset. Note that the green/red boxes denote true/false retrieval results, respectively.

We present some retrieval examples with five retrieved images for each query in Figure 3. As the visualization shows, our GPS obtains better retrieval results than the baseline. In the first row of Figure 3, the baseline returns a false result at Rank-5 because the retrieved person shares the same gender, also wears a hat, and so on, and differs only in the color of the clothes. By leveraging our GPS, the extracted features capture attribute and body part information more robustly, which leads to better retrieval results for the ReID model. In the second row, the model with our GPS gives better results by extracting more information about the relationship between the 'backpack' attribute and this person identity, thereby eliminating false cases. We also show an example in which our GPS does not yet produce entirely correct retrieval results, in the third row of Figure 3. In this case, the lower body of the probe image is partly covered by the bicycle. Thus, the relevant features (i.e., the color of the pants) are not fully captured, which results in feature misalignment between the probe image and the retrieval results.


This paper proposes Graph-based Person Signature (GPS) that effectively captures the dependencies of person attributes and body parts information. We utilize the GCN on the GPS to propagate the information among nodes in the graph and integrate the graph features into a novel multi-branch multi-task network. The experimental results on benchmark datasets confirm the effectiveness of our GPS and demonstrate that our GPS performs better than recent state-of-the-art attribute-based and mask-guided ReID methods.

Open Source

🐱 Github: https://github.com/aioz-ai/CVPRW21_GPS

Compact Trilinear Interaction for Visual Question Answering

In Visual Question Answering (VQA), answers have a great correlation with question meaning and visual contents. Thus, to selectively utilize image, question and answer information, we propose a novel trilinear interaction model which simultaneously learns high-level associations between these three inputs. In addition, to overcome the interaction complexity, we introduce a multimodal tensor-based PARALIND decomposition which efficiently parameterizes the trilinear interaction between the three inputs. Moreover, knowledge distillation is applied in Free-form Open-ended VQA, not only to reduce the computational cost and required memory but also to transfer knowledge from the trilinear interaction model to a bilinear interaction model. Extensive experiments on the benchmark datasets TDIUC, VQA-2.0, and Visual7W show that the proposed compact trilinear interaction model achieves state-of-the-art results on all three datasets.

For the free-form open-ended VQA task, CTI achieved 67.4 on VQA-2.0 and 87.0 on the TDIUC dataset in the VQA accuracy metric.

For the multiple choice VQA task, CTI achieved 72.3 on the Visual7W dataset in the MC-VQA accuracy metric.

Compact Trilinear Interaction in VQA.

Let $M = \{M_1, M_2, M_3\}$ be the representations of three inputs. $M_t \in \textbf{R}^{n_t \times d_t}$, where $n_t$ is the number of channels of the input $M_t$ and $d_t$ is the dimension of each channel.
For example, if $M_1$ is the region-based representation for an image, then $n_1$ is the number of regions and $d_1$ is the dimension of the feature representation for each region. Let $m_{t_e} \in \textbf{R}^{1 \times d_{t}}$ be the $e^{th}$ row of $M_t$, i.e., the feature representation of the $e^{th}$ channel in $M_t$, where $t \in \{1, 2, 3\}$.

The input for training VQA is a set of $(V,Q,A)$ in which $V$ is an image representation; $V \in \textbf{R}^{v \times d_v}$ where $v$ is the number of regions of interest (or bounding boxes) in the image and $d_v$ is the dimension of the representation for a region; $Q$ is a question representation; $Q \in \textbf{R}^{q \times d_q }$
where $q$ is the number of hidden states and $d_q$ is the dimension for each hidden state.
$A$ is an answer representation; $A \in \textbf{R}^{a \times d_a}$
where $a$ is the number of hidden states and $d_a$ is the dimension for each hidden state.

We firstly compute the attention map M\mathcal{M} as follows:

$$\mathcal{M} = \sum^R_{r=1} {\llbracket \mathcal{G}_r; V W_{v_r}, Q W_{q_r}, A W_{a_r}\rrbracket}$$

Then the joint representation zz is computed as follows:

$$z^T= \sum_{i=1}^{v}\sum_{j=1}^{q}\sum_{k=1}^{a} \mathcal{M}_{ijk}\left( V_{i}W_{z_v} \circ Q_{j}W_{z_q} \circ A_{k}W_{z_a}\right)$$

where $W_{v_r},W_{q_r}, W_{a_r}$ and $W_{z_v},W_{z_q}, W_{z_a}$ are learnable factor matrices; each $\mathcal{G}_r$ is a learnable Tucker tensor.
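The triple sum for the joint representation $z$ can be written out directly as nested loops. In this sketch the attention map $\mathcal{M}$ is simply passed in as a given array rather than computed from the decomposition, and all shapes, values, and function names are hypothetical; a practical implementation would use a tensor library instead of Python lists.

```python
def vecmat(x, W):
    """Row vector x (length d) times matrix W (d x d_z)."""
    return [sum(xi * w for xi, w in zip(x, col)) for col in zip(*W)]


def joint_representation(M, V, Q, A, W_zv, W_zq, W_za):
    """z = sum_{i,j,k} M[i][j][k] * (V_i W_zv * Q_j W_zq * A_k W_za),
    where * between the projected vectors is the element-wise product."""
    pv = [vecmat(v, W_zv) for v in V]   # projected image regions
    pq = [vecmat(q, W_zq) for q in Q]   # projected question states
    pa = [vecmat(a, W_za) for a in A]   # projected answer states
    d_z = len(pv[0])
    z = [0.0] * d_z
    for i, vi in enumerate(pv):
        for j, qj in enumerate(pq):
            for k, ak in enumerate(pa):
                m = M[i][j][k]
                for d in range(d_z):
                    z[d] += m * vi[d] * qj[d] * ak[d]
    return z


# Hypothetical toy inputs: 2 regions, 1 question state, 1 answer state.
identity = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
Q = [[1.0, 1.0]]
A = [[2.0, 0.5]]
M = [[[0.5]], [[0.5]]]   # attention weights, shape (v, q, a)

z = joint_representation(M, V, Q, A, identity, identity, identity)
print(z)
```

Each entry of $\mathcal{M}$ weights one (region, question state, answer state) triple, so $z$ is an attention-weighted sum of element-wise products of the three projected inputs.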

Integrating CTI into different VQA tasks

For multiple choice VQA


Figure 1. The model when CTI is applied to MC VQA.

Each input question and each answer are trimmed to a maximum of 12 words and zero-padded if shorter than 12 words. Each word is then represented by a 300-D GloVe word embedding. Each image is represented by a $14 \times 14 \times 2048$ grid feature (i.e., $196$ cells, each with a $2048$-D feature), extracted from the second-to-last layer of ResNet-152, which is pre-trained on ImageNet.
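The trimming and zero-padding step can be sketched as follows. The function name and the use of integer token ids (with 0 reserved for padding) are assumptions made for illustration; the actual preprocessing code may differ.

```python
def trim_and_pad(token_ids, max_len=12, pad_id=0):
    """Cut a sequence to `max_len` tokens, then right-pad with `pad_id`."""
    trimmed = token_ids[:max_len]
    return trimmed + [pad_id] * (max_len - len(trimmed))


short_question = trim_and_pad([5, 3, 9])
long_question = trim_and_pad(list(range(1, 20)))
```

Every question and answer then has exactly 12 positions, which lets a batch be stacked into one fixed-size array before the embedding lookup.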

Input samples are divided into positive samples and negative samples. A positive sample, which is labelled as $1$ in binary classification, contains an image, a question, and the right answer. A negative sample, which is labelled as $0$ in binary classification, contains an image, a question, and a wrong answer. These samples are then passed through CTI to get the joint representation $z$. The joint representation is passed through a binary classifier to get the prediction. The Binary Cross Entropy loss is used for training the model.

For free-form open-ended VQA


Figure 2. The model when CTI is applied to FFOE VQA.

Unlike MC VQA, FFOE VQA treats answering as a classification problem over the set of predefined answers. Hence the set of possible answers for each question-image pair is much larger than in MC VQA. For each question-image input, the model would have to take every possible answer from its answer list to compute the joint representation, which causes a high computational cost.

In addition, CTI requires all three inputs $V, Q, A$ to compute the joint representation. However, during testing, no answer information is available in FFOE VQA. To overcome these challenges, we propose to use Knowledge Distillation to transfer the learned knowledge from a teacher model to a student model.

The loss function for the student model is defined as:

$$\mathcal{L}_{KD} = \alpha T^2 \mathcal{L}_{CE}(Q^\tau_S, Q^\tau_T) + (1-\alpha)\mathcal{L}_{CE}(Q_S,y_{true})$$

where $\mathcal{L}_{CE}$ stands for Cross Entropy loss; $Q_S$ is the standard softmax output of the student; $y_{true}$ is the ground-truth answer labels;
$\alpha$ is a hyper-parameter for controlling the importance of each loss component; $Q^\tau_S, Q^\tau_T$ are the softened outputs of the student and the teacher using the same temperature parameter $T$, which are computed as follows:

$$Q^\tau_i = \frac{\exp(l_i/T)}{\sum_{j} \exp(l_j/T)}$$

where, for both the teacher and the student models, the logits $l$ are the predictions output by the corresponding classifiers.
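The distillation loss above can be sketched in a few lines. This follows the formula as written, with two assumptions that the text does not state explicitly: the teacher's softened output is treated as the target distribution in the first cross entropy term, and $y_{true}$ is given as a one-hot (or soft) label vector. The function names and the default values of `alpha` and `T` are hypothetical.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax; T = 1 gives the standard softmax."""
    scaled = [l / T for l in logits]
    m = max(scaled)                      # subtract the max for stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(pred, target):
    """-sum_i target_i * log(pred_i) for two distributions."""
    return -sum(t * math.log(p) for p, t in zip(pred, target) if t > 0)


def kd_loss(student_logits, teacher_logits, y_true, alpha=0.5, T=3.0):
    """alpha * T^2 * CE(Q_S^tau, Q_T^tau) + (1 - alpha) * CE(Q_S, y_true)."""
    q_s_soft = softmax(student_logits, T)
    q_t_soft = softmax(teacher_logits, T)
    q_s = softmax(student_logits)
    soft_term = cross_entropy(q_s_soft, q_t_soft)
    hard_term = cross_entropy(q_s, y_true)
    return alpha * T ** 2 * soft_term + (1 - alpha) * hard_term


# Hypothetical logits over three candidate answers.
loss = kd_loss(student_logits=[1.0, 0.2, -0.5],
               teacher_logits=[2.0, 0.1, -1.0],
               y_true=[1.0, 0.0, 0.0])
```

Since the student only needs the image and question at test time, the answer-dependent knowledge of the trilinear teacher reaches it solely through the soft-target term during training.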



Table 1. Performance of CTI and BAN2, SAN on the VQA-2.0 validation set and test-dev set. BAN2-CTI and SAN-CTI are student models trained under the teacher model.

To further evaluate the effectiveness of CTI, we conduct a detailed comparison with the current state of the art. For FFOE VQA, we compare CTI with the recent state-of-the-art methods on TDIUC and VQA-2.0 datasets. For MC VQA, we compare with the state-of-the-art methods on Visual7W dataset.


Table 2. Performance comparison between different approaches with different evaluation metrics on TDIUC validation set. BAN2-CTI and SAN-CTI are the student models trained under compact trilinear interaction teacher model.

Regarding FFOE VQA, Table 1 and Table 2 show comparati