Commit c7c99f20 authored by Tomas Rokos
parents 84ee67e7 b469586b
import os

import numpy as np
from keras import models

from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre


class DNNSimpleClassifier(BaseClassifier):
    SIMPLE_MODEL_PATH = os.path.join('models', 'simple_dnn', 'model.h5')

    def __init__(self, model_path, **kwargs):
        super().__init__(**kwargs)
        # The previous revision built this dense network in code
        # (Dense 1024 -> 512 -> 128 -> 10 with softmax, Adam optimizer,
        # sparse categorical cross-entropy); after this merge the trained
        # model is loaded from disk instead.
        self.model = models.load_model(model_path)

    def classify(self, mfcc):
        mfcc = np.expand_dims(mfcc, axis=0)
@@ -36,7 +25,4 @@ class DNNSimpleClassifier(BaseClassifier):
class DNNSimpleClassifierFactory:
    @staticmethod
    def create():
        return DNNSimpleClassifier(DNNSimpleClassifier.SIMPLE_MODEL_PATH)
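For context, a minimal usage sketch of the factory pattern shared by the classifiers in this merge. It mirrors the Flask handler further below and assumes a trained model.h5 exists under models/simple_dnn/ and that BaseClassifier provides classify_mfcc (the app calls it that way); 'example.wav' is a placeholder path.

import librosa

from audio_classification.dnn_simple_classifier.dnn_simple_classifier import DNNSimpleClassifierFactory

# Compute the 13 MFCCs exactly as the web app does for an uploaded file.
audio, sr = librosa.load('example.wav')
mfcc = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)

classifier = DNNSimpleClassifierFactory.create()
print(classifier.classify_mfcc(mfcc))  # prints a decoded genre name, e.g. 'rock'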
import os
import numpy as np
from keras import models
from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre
class ResNetClassifier(BaseClassifier):
    RESNET_MODEL_PATH = os.path.join('models', 'resnet_classifier', 'model.h5')

    def __init__(self, model_path, **kwargs):
        super().__init__(**kwargs)
        self.model = models.load_model(model_path)

    def classify(self, mfcc):
        mfcc = np.expand_dims(mfcc, axis=0)
        mfcc = np.expand_dims(mfcc, axis=3)
        return decode_genre(self.model.predict(mfcc).argmax(axis=1)[0])


class ResNetClassifierFactory:
    @staticmethod
    def create():
        return ResNetClassifier(ResNetClassifier.RESNET_MODEL_PATH)
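To make the shape handling above explicit: unlike the simple DNN, the ResNet model consumes the whole MFCC matrix as a one-channel image, so classify turns a 2-D (n_mfcc, frames) array into an NHWC batch of one. A small sketch, with an illustrative frame count:

import numpy as np

mfcc = np.zeros((13, 1292))            # 13 MFCCs over a ~30 s clip (illustrative size)
batch = np.expand_dims(mfcc, axis=0)   # (1, 13, 1292)    -- add the batch dimension
batch = np.expand_dims(batch, axis=3)  # (1, 13, 1292, 1) -- add the single channel axis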
import os
import tensorflow as tf
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from audio_classification.classifier.base_classifier import BaseClassifier
from audio_classification.util.transform import decode_genre
MODEL_FOLDER = os.path.join('models', 'siamese')
MODEL_PATH = os.path.join(MODEL_FOLDER, 'model.h5')
KNN_SAMPLES_PATH = os.path.join(MODEL_FOLDER, 'data.npy')
KNN_LABELS_PATH = os.path.join(MODEL_FOLDER, 'labels.npy')
class ResNetSimilarityKnn(BaseClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        x = np.load(KNN_SAMPLES_PATH)
        y = np.load(KNN_LABELS_PATH)
        self.knn = KNeighborsClassifier(n_neighbors=12)
        self.knn.fit(x, y)
        self.embedding_model = tf.keras.models.load_model(MODEL_PATH, compile=False)

    def classify(self, mfcc):
        mfcc = np.expand_dims(mfcc, axis=0)
        mfcc = np.expand_dims(mfcc, axis=3)
        feature_vector = self.embedding_model.predict(mfcc)
        prediction = self.knn.predict(feature_vector)
        return decode_genre(prediction[0])


class ResNetSimilarityKnnFactory:
    @staticmethod
    def create():
        return ResNetSimilarityKnn()
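A hedged sketch of how the KNN reference arrays (data.npy and labels.npy) could be regenerated from labeled examples with the same embedding model. The labeled_mfccs and labels below are hypothetical placeholders, not data shipped in this merge, and the MFCC shape must match whatever the embedding model was trained on.

import numpy as np
import tensorflow as tf

from audio_classification.resnet_similarity_knn.nn_similarity import (
    MODEL_PATH, KNN_SAMPLES_PATH, KNN_LABELS_PATH)

# Hypothetical labeled data: equally sized MFCC matrices and integer genre codes.
labeled_mfccs = [np.zeros((13, 1292), dtype=np.float32) for _ in range(3)]
labels = [0, 1, 2]

embedding_model = tf.keras.models.load_model(MODEL_PATH, compile=False)

# Stack into an (N, n_mfcc, frames, 1) batch, embed, and store for the KNN.
batch = np.stack([np.expand_dims(m, axis=2) for m in labeled_mfccs])
features = embedding_model.predict(batch)

np.save(KNN_SAMPLES_PATH, features)
np.save(KNN_LABELS_PATH, np.asarray(labels))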
@@ -7,6 +7,8 @@ from werkzeug.utils import redirect, secure_filename
from audio_classification.dnn_simple_classifier.dnn_simple_classifier import DNNSimpleClassifierFactory
from audio_classification.dtwclassifier.dtwclassifier import DTWClassifierFactory
from audio_classification.resnet_classifier.resnet_classifier import ResNetClassifierFactory
from audio_classification.resnet_similarity_knn.nn_similarity import ResNetSimilarityKnnFactory
from audio_classification.run.utils import generate_soundwave
UPLOAD_FOLDER = 'upload'
@@ -17,6 +19,12 @@ if not os.path.exists(UPLOAD_FOLDER):
    os.mkdir(UPLOAD_FOLDER)
simple_dnn_classifier = DNNSimpleClassifierFactory.create()
resnet_classifier = ResNetClassifierFactory.create()
dtw_classifier = DTWClassifierFactory.create()
resnet_similarity_knn = ResNetSimilarityKnnFactory.create()
@app.route('/')
def index():
    return send_from_directory('static', 'index.html')
@@ -49,8 +57,6 @@ def classify():
    file.save(filepath)
    audio_file, sr = librosa.load(filepath)
    mfcc = librosa.feature.mfcc(y=audio_file, sr=sr, n_mfcc=13)
    return jsonify({
        "classification": [
@@ -58,6 +64,14 @@ def classify():
                'name': 'Simple DNN Classifier',
                'genre': simple_dnn_classifier.classify_mfcc(mfcc)
            },
            {
                'name': 'ResNet-18 Classifier',
                'genre': resnet_classifier.classify_mfcc(mfcc)
            },
            {
                'name': 'ResNet-18 Embedding + KNN Classifier',
                'genre': resnet_similarity_knn.classify_mfcc(mfcc)
            },
            {
                'name': 'DTW classifier',
                'genre': dtw_classifier.classify_mfcc(mfcc)
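For reference, a hedged client-side sketch of calling the classification endpoint extended above. The route path /classify, the port, and the multipart field name 'file' are assumptions, since the route decorator and request-parsing lines fall outside the hunks shown.

import requests

with open('example.wav', 'rb') as f:
    resp = requests.post('http://localhost:5000/classify', files={'file': f})

# The handler returns a JSON object with a 'classification' list of name/genre pairs.
for entry in resp.json()['classification']:
    print(entry['name'], '->', entry['genre'])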
@@ -5,19 +5,24 @@ argon2-cffi==20.1.0
astunparse==1.6.3
async-generator==1.10
attrs==20.2.0
-e git+https://gitlab.fit.cvut.cz/rudolja4/ni-vmm-music-genre-classification.git@d7dd7a87da83dc20a4972492c999c5d48c6659f8#egg=audio_classification
audioread==2.1.9
backcall==0.2.0
bleach==3.2.1
bokeh==2.2.3
cached-property==1.5.2
cachetools==4.1.1
certifi==2020.6.20
cffi==1.14.3
chardet==3.0.4
click==7.1.2
cloudpickle==1.6.0
colorcet==2.0.2
cycler==0.10.0
datashape==0.5.2
decorator==4.4.2
defusedxml==0.6.0
dtw==1.4.0
entrypoints==0.3
Flask==1.1.2
gast==0.3.3
@@ -26,6 +31,8 @@ google-auth-oauthlib==0.4.2
google-pasta==0.2.0
grpcio==1.33.2
h5py==2.10.0
HeapDict==1.0.1
holoviews==1.14.0
idna==2.10
importlib-metadata==2.0.0
ipykernel==5.3.4
@@ -44,10 +51,13 @@ Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
librosa==0.8.0
llvmlite==0.34.0
locket==0.2.0
Markdown==3.3.3
MarkupSafe==1.1.1
matplotlib==3.3.2
mistune==0.8.4
msgpack==1.0.0
multipledispatch==0.6.0
nbclient==0.5.1
nbconvert==6.0.7
nbformat==5.0.8
@@ -60,7 +70,10 @@ opt-einsum==3.3.0
packaging==20.4
pandas==1.1.4
pandocfilters==1.4.3
panel==0.10.2
param==1.10.0
parso==0.7.1
partd==1.1.0
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.0.1
@@ -72,11 +85,13 @@ ptyprocess==0.6.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pyct==0.4.8
Pygments==2.7.2
pyparsing==2.4.7
pyrsistent==0.17.3
python-dateutil==2.8.1
pytz==2020.4
pyviz-comms==0.7.6
PyYAML==5.3.1
pyzmq==19.0.2
requests==2.24.0
@@ -88,20 +103,30 @@ scipy==1.5.3
seaborn==0.11.0
Send2Trash==1.5.0
six==1.15.0
sortedcontainers==2.3.0
SoundFile==0.10.3.post1
tblib==1.7.0
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-addons==0.11.2
tensorflow-estimator==2.3.0
termcolor==1.1.0
terminado==0.9.1
testpath==0.4.4
threadpoolctl==2.1.0
toolz==0.11.1
tornado==6.1
tqdm==4.54.1
traitlets==5.0.5
typeguard==2.10.0
typing-extensions==3.7.4.3
umap-learn==0.4.6
urllib3==1.25.11
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==1.0.1
wrapt==1.12.1
xarray==0.16.2
zict==2.0.0
zipp==3.4.0