Set Representation Learning with Generalized Sliced-Wasserstein Embeddings
Abstract
An increasing number of machine learning tasks involve learning representations from set-structured data. Solutions to these problems compose permutation-equivariant modules (e.g., self-attention, or individual processing via feed-forward neural networks) with permutation-invariant modules (e.g., global average pooling, or pooling by multi-head attention). In this paper, we propose a geometrically interpretable framework for learning representations from set-structured data, rooted in the optimal mass transportation problem. In particular, we treat the elements of a set as samples from a probability measure and propose an exact Euclidean embedding for Generalized Sliced-Wasserstein (GSW) distances to learn effectively from set-structured data. We evaluate the proposed framework on multiple supervised and unsupervised set learning tasks and demonstrate its superiority over state-of-the-art set representation learning approaches.
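The idea of an exact Euclidean embedding for a sliced-Wasserstein distance can be illustrated with a minimal sketch. It relies on a standard fact: in one dimension, the 2-Wasserstein distance between two measures equals the L2 distance between their quantile functions. Projecting a set onto a direction, taking quantiles, and stacking the results over many directions therefore yields a fixed-length vector whose Euclidean distance to another set's vector is a Monte Carlo estimate of the sliced 2-Wasserstein distance. This is not the paper's construction. It assumes linear slicing directions (the standard sliced-Wasserstein special case, not the generalized slices GSW allows), random rather than learned directions, and a fixed grid of quantile levels, since this abstract does not say how the slices are discretized. The helper names `sample_directions`, `quantiles`, and `sw_embedding` are hypothetical.

```python
import math
import random


def sample_directions(dim, num_slices, seed=0):
    """Draw num_slices random unit vectors in R^dim (linear slices, an assumption)."""
    rng = random.Random(seed)
    directions = []
    for _ in range(num_slices):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in v))
        directions.append([x / norm for x in v])
    return directions


def quantiles(values, num_points):
    """Empirical quantile function of `values` evaluated at the levels
    (i + 0.5) / num_points for i = 0 .. num_points - 1."""
    s = sorted(values)  # sorting makes the result independent of element order
    n = len(s)
    return [s[min(int((i + 0.5) / num_points * n), n - 1)] for i in range(num_points)]


def sw_embedding(points, directions, num_points):
    """Map a set of points (any size) to one vector of length
    len(directions) * num_points.

    Each slice contributes num_points quantiles of the projected set. The
    1 / sqrt(L * num_points) scale makes the Euclidean distance between two
    embeddings equal to the Monte Carlo sliced 2-Wasserstein estimate
    (exactly when both sets have num_points elements, approximately otherwise).
    """
    scale = 1.0 / math.sqrt(len(directions) * num_points)
    embedding = []
    for theta in directions:
        projected = [sum(t * x for t, x in zip(theta, p)) for p in points]
        embedding.extend(scale * q for q in quantiles(projected, num_points))
    return embedding


if __name__ == "__main__":
    X = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    Y = [[x + 2.0, y + 2.0] for x, y in X]  # X translated by (2, 2)
    dirs = sample_directions(dim=2, num_slices=64, seed=1)
    ex, ey = sw_embedding(X, dirs, 4), sw_embedding(Y, dirs, 4)
    print("permutation invariant:", sw_embedding(X[::-1], dirs, 4) == ex)
    print("embedding distance:", math.dist(ex, ey))
```

Because sorting discards element order, reordering a set leaves its embedding unchanged, which is the permutation invariance a set representation needs. For the translated pair above, the population sliced 2-Wasserstein distance with uniformly distributed directions is |(2, 2)| / √2 = 2, which the 64-slice estimate approximates.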
- Publication: arXiv e-prints
- Pub Date: March 2021
- DOI: 10.48550/arXiv.2103.03892
- arXiv: arXiv:2103.03892
- Bibcode: 2021arXiv210303892N
- Keywords: Computer Science - Machine Learning