Commit f3125b1d authored by Tomas Rokos's avatar Tomas Rokos

Add vgg16 to the project

parent e0ffd75f
......@@ -22,6 +22,7 @@ if not os.path.exists(UPLOAD_FOLDER):
simple_dnn_classifier = DNNSimpleClassifierFactory.create()
resnet_classifier = ResNetClassifierFactory.create()
dtw_classifier = DTWClassifierFactory.create()
vgg_classifier = ResNetClassifierFactory.create()
resnet_similarity_knn = ResNetSimilarityKnnFactory.create()
......@@ -68,6 +69,10 @@ def classify():
'name': 'ResNet-18 Classifier',
'genre': resnet_classifier.classify_mfcc(mfcc)
},
{
'name': 'VGG-16 Classifier',
'genre': vgg_classifier.classify_mfcc(mfcc)
},
{
'name': 'ResNet-18 Embedding + KNN Classifier',
'genre': resnet_similarity_knn.classify_mfcc(mfcc)
......
import os
import numpy as np
from keras import models
from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre
class VggClassifier(BaseClassifier):
RESNET_MODEL_PATH = os.path.join('models', 'vgg_classifier', 'model.h5')
def __init__(self, model_path, **kwargs):
super().__init__(**kwargs)
self.model = models.load_model(model_path)
def classify(self, mfcc):
mfcc = np.expand_dims(mfcc, axis=0)
mfcc = np.expand_dims(mfcc, axis=3)
return decode_genre(self.model.predict(mfcc).argmax(axis=1)[0])
class ResNetClassifierFactory:
@staticmethod
def create():
return VggClassifier(VggClassifier.RESNET_MODEL_PATH)
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment