So far, I have read some highly cited metric learning papers. The general idea of these papers is to learn a mapping such that mapped data points with the same label lie close to each other and far from samples of other classes. To evaluate such techniques, they report the accuracy of a KNN classifier on the learned embedding. So my question is: if we have a labelled dataset and we are interested in increasing classification accuracy, why don't we learn a classifier on the original data points? I mean, instead of finding a new embedding that suits a KNN classifier, we could learn a classifier that fits the (non-embedded) data points directly. Based on what I have read so far, the classification accuracy of such classifiers is much better than that of metric learning approaches. Is there a study showing that metric learning + KNN performs better than fitting a (good) classifier, at least on some datasets?
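For reference, the evaluation protocol described above (KNN accuracy measured in the embedding space) can be sketched as follows. This is a minimal illustration, not code from any particular paper: the function names are hypothetical, and `train_emb`/`test_emb` are assumed to already be the output of whatever learned mapping is being evaluated.

```python
import math
from collections import Counter

def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def knn_predict(train_emb, train_labels, query, k=3):
    """Majority vote among the k training embeddings closest to `query`."""
    nearest = sorted(
        range(len(train_emb)), key=lambda i: euclidean(train_emb[i], query)
    )[:k]
    votes = Counter(train_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]

def knn_accuracy(train_emb, train_labels, test_emb, test_labels, k=3):
    """Fraction of test embeddings whose KNN prediction matches the true label."""
    correct = sum(
        knn_predict(train_emb, train_labels, q, k) == y
        for q, y in zip(test_emb, test_labels)
    )
    return correct / len(test_labels)
```

In a real evaluation, the embeddings would be the learned mapping applied to the training and test splits, and the resulting accuracy is the number that gets compared with a classifier trained on the raw inputs.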
Metric learning models CAN BE classifiers. So I will answer the question of why we need metric learning for classification.
Let me give you an example. Suppose you have a dataset with millions of classes, and some classes have only a handful of examples, say fewer than 5. If you use classifiers such as SVMs or standard CNNs, you will find them nearly impossible to train well, because these discriminative models will largely ignore the classes with few examples.
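For metric learning models, this is much less of a problem, since they learn a similarity (distance) function that is shared by all classes rather than a separate decision boundary for each class, so a class with only a few examples can still be recognized by where its examples land in the embedding space.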
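One common objective used to train metric-learning models is the triplet loss (as in FaceNet): an anchor should be closer to an example of its own class (the positive) than to an example of another class (the negative) by at least a margin. The sketch below computes it with squared Euclidean distances on plain vectors; the function names and the default margin of 0.2 are assumptions for illustration, and in practice the vectors would be outputs of a trainable network whose weights are updated to minimize this loss. Note that the loss only compares examples with each other and has no per-class output weights.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def triplet_loss(anchor, positive, negative, margin=0.2):
    """Hinge-style triplet loss: zero once the negative is at least `margin`
    farther (in squared distance) from the anchor than the positive is."""
    d_pos = squared_distance(anchor, positive)
    d_neg = squared_distance(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)

def batch_triplet_loss(triplets, margin=0.2):
    """Mean triplet loss over a list of (anchor, positive, negative) tuples."""
    return sum(triplet_loss(a, p, n, margin) for a, p, n in triplets) / len(triplets)
```

A triplet that is already well separated contributes zero loss, so training focuses on the triplets where a different-class example is still too close.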
By the way, a large number of classes is itself a challenge for discriminative models (a softmax classifier, for example, needs a separate output unit and weight vector for every class).
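To put a rough number on this: a fully connected softmax layer stores one weight vector plus one bias per class, so its size grows linearly with the number of classes. The helper below is hypothetical, and the 512-dimensional feature size and one million classes are illustrative assumptions, not figures from a specific model.

```python
def softmax_head_params(feature_dim, num_classes):
    """Parameters in a fully connected softmax layer: one weight vector
    of length `feature_dim` plus one bias per class."""
    return feature_dim * num_classes + num_classes

# Hypothetical setting: 512-d features, one million classes.
print(softmax_head_params(512, 1_000_000))  # 513000000
```

A model that only outputs a 512-dimensional embedding typically has no such per-class layer; supporting more classes means storing more reference embeddings rather than training more weights.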
Real-life challenges like these motivate us to explore better models.
As @Tengerye mentioned, you can use models trained with metric learning for classification. KNN is the simplest approach, but you can also take the embeddings of your data and train any other classifier on them, such as an SVM or a neural network. In this case, the role of metric learning is to map the original input space to a new space that is easier for a classifier to handle.
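As a sketch of training a separate classifier on the embeddings, the code below fits a nearest-centroid classifier (a deliberately simple stand-in for the SVM or neural network mentioned above, chosen here only to keep the example short). The `embed` function is a hypothetical placeholder for the trained metric-learning model: it applies a fixed projection matrix, which in the usage example keeps the first input feature and discards the second, assumed to be noise.

```python
def embed(x, projection):
    """Hypothetical learned mapping: multiply x by a fixed projection matrix
    (one row per output dimension). A real model would be a trained network."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in projection]

class NearestCentroid:
    """Assigns each point to the class whose mean embedding is closest."""

    def fit(self, embeddings, labels):
        sums, counts = {}, {}
        for e, y in zip(embeddings, labels):
            if y not in sums:
                sums[y] = [0.0] * len(e)
                counts[y] = 0
            sums[y] = [s + v for s, v in zip(sums[y], e)]
            counts[y] += 1
        # Centroid of each class = element-wise mean of its embeddings.
        self.centroids = {y: [s / counts[y] for s in sums[y]] for y in sums}
        return self

    def predict(self, embedding):
        # Pick the class whose centroid has the smallest squared distance.
        return min(
            self.centroids,
            key=lambda y: sum((c - v) ** 2 for c, v in zip(self.centroids[y], embedding)),
        )

# Usage: the second raw feature is large noise; the projection drops it.
projection = [[1.0, 0.0]]
X = [[0.0, 7.0], [0.2, -3.0], [5.0, 6.0], [5.2, -4.0]]
y = ["a", "a", "b", "b"]
clf = NearestCentroid().fit([embed(x, projection) for x in X], y)
print(clf.predict(embed([0.1, 100.0], projection)))  # a
```

Any classifier could replace `NearestCentroid` here; the point is only that it is trained on `embed(x)` rather than on `x`.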
Apart from discriminative models being hard to train when the data is imbalanced or, worse, when there are very few examples per class, they also cannot easily be extended to new classes.
Take facial recognition, for example. If a facial recognition model is trained as a classifier, it only works for the faces it has seen during training and will not work for any new face. Of course, you could add images of the faces you want to include and retrain the model, or fine-tune it if possible, but this is highly impractical. In contrast, a facial recognition model trained with metric learning can generate embeddings for new faces, which can simply be added to the KNN index, and your system can then identify the new person from his or her image.
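The enrollment workflow described above can be sketched as a small nearest-neighbor gallery. This is a minimal illustration under several assumptions: the `FaceGallery` class is hypothetical, the embeddings are toy 2-D vectors standing in for the output of a metric-learning face model (not shown), and the distance threshold (an addition beyond plain KNN, used here so unknown faces can be rejected) is an arbitrary value that would need tuning in practice.

```python
import math

class FaceGallery:
    """Open-set identification by 1-nearest-neighbor over stored embeddings.
    New identities are enrolled by storing an embedding; nothing is retrained."""

    def __init__(self, threshold):
        self.threshold = threshold  # max distance to accept a match (assumed value)
        self.entries = []  # list of (name, embedding) pairs

    def enroll(self, name, embedding):
        self.entries.append((name, list(embedding)))

    def identify(self, embedding):
        """Return the closest enrolled name, or None if no one is within threshold."""
        best_name, best_dist = None, math.inf
        for name, ref in self.entries:
            dist = math.dist(ref, embedding)
            if dist < best_dist:
                best_name, best_dist = name, dist
        return best_name if best_dist <= self.threshold else None

gallery = FaceGallery(threshold=0.5)
gallery.enroll("alice", [1.0, 0.0])
gallery.enroll("bob", [0.0, 1.0])
print(gallery.identify([-0.9, 0.05]))  # None: this person is not enrolled yet
gallery.enroll("carol", [-1.0, 0.0])   # add a new person without retraining
print(gallery.identify([-0.9, 0.05]))  # carol
```

With a classifier, adding "carol" would mean changing the output layer and retraining; here it is a single `enroll` call.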