Commit c7f7cd15 authored by Jan Rudolf's avatar Jan Rudolf

ADD resnet classifier and embeding, refactor

parent 9b092467
This diff is collapsed.
from keras import models, layers
import pandas as pd
import numpy as np
import os
import numpy as np
from keras import models
from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre
class DNNSimpleClassifier(BaseClassifier):
MODEL_PATH = os.path.join('models', 'simple_dnn', 'model.h5')
SIMPLE_MODEL_PATH = os.path.join('models', 'simple_dnn', 'model.h5')
def __init__(self, **kwargs):
def __init__(self, model_path, **kwargs):
super().__init__(**kwargs)
self.model = model = models.Sequential()
model.add(layers.Dense(1024, activation='relu', input_shape=(13,)))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
def load(self):
self.model = models.load_model(self.MODEL_PATH)
self.model = models.load_model(model_path)
def classify(self, mfcc):
mfcc = np.expand_dims(mfcc, axis=0)
......@@ -36,7 +25,4 @@ class DNNSimpleClassifier(BaseClassifier):
class DNNSimpleClassifierFactory:
@staticmethod
def create():
cls = DNNSimpleClassifier()
cls.load()
return cls
return DNNSimpleClassifier(DNNSimpleClassifier.SIMPLE_MODEL_PATH)
import os
import numpy as np
from keras import models
from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre
class ResNetClassifier(BaseClassifier):
RESNET_MODEL_PATH = os.path.join('models', 'resnet_classifier', 'model.h5')
def __init__(self, model_path, **kwargs):
super().__init__(**kwargs)
self.model = models.load_model(model_path)
def classify(self, mfcc):
mfcc = np.expand_dims(mfcc, axis=0)
mfcc = np.expand_dims(mfcc, axis=3)
return decode_genre(self.model.predict(mfcc).argmax(axis=1)[0])
class ResNetClassifierFactory:
@staticmethod
def create():
return ResNetClassifier(ResNetClassifier.RESNET_MODEL_PATH)
import os
from abc import ABC
import tensorflow as tf
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre
MODEL_FOLDER = os.path.join('models', 'siamese')
MODEL_PATH = os.path.join(MODEL_FOLDER, 'model.h5')
KNN_SAMPLES_PATH = os.path.join(MODEL_FOLDER, 'data.npy')
KNN_LABELS_PATH = os.path.join(MODEL_FOLDER, 'labels.npy')
class ResNetSimilarityKnn(BaseClassifier):
def __init__(self, **kwargs):
super().__init__(**kwargs)
x = np.load(KNN_SAMPLES_PATH)
y = np.load(KNN_LABELS_PATH)
self.knn = KNeighborsClassifier(n_neighbors=12)
self.knn.fit(x, y)
self.embedding_model = tf.keras.models.load_model(MODEL_PATH, compile=False)
def classify(self, mfcc):
mfcc = np.expand_dims(mfcc, axis=0)
mfcc = np.expand_dims(mfcc, axis=3)
feature_vector = self.embedding_model.predict(mfcc)
prediction = self.knn.predict(feature_vector)
return decode_genre(prediction[0])
class ResNetSimilarityKnnFactory:
@staticmethod
def create():
return ResNetSimilarityKnn()
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment