Commit d0bd0c32 authored by Tomas Rokos

Improve the DTW classifier, try out basic DNN classifier

parent 5208df95
......@@ -2,11 +2,12 @@ venv
data
.idea
genres.tar.gz
*/.ipynb_checkpoints
.ipynb_checkpoints/*
input
upload
preprocessed_*
audio_classification.egg-info
*.cpython-37.pyc
node_modules
.cache
\ No newline at end of file
.cache
graphs/*
This diff is collapsed.
import numpy as np
import math
def euclid(a, b):
    """Return the Euclidean distance between two scalar values.

    The distance is the square root of the squared difference, which for
    real scalars is the same as the absolute difference, returned as a
    ``float``.
    """
    delta = a - b
    squared = delta ** 2
    return math.sqrt(squared)
def abs_sub(a, b):
    """Return the absolute difference ``|a - b|``."""
    difference = a - b
    return abs(difference)
def dtw(v1, v2, dist_fn=euclid):
    """Compute the dynamic-time-warping cost between two sequences.

    An accumulated-cost table with one extra leading row and column is
    filled cell by cell: every cell holds the local cost of pairing the two
    elements plus the cheapest of its left, upper and upper-left neighbours.

    Args:
        v1: First indexable sequence (must support ``len`` and ``[]``).
        v2: Second indexable sequence.
        dist_fn: Callable ``(a, b) -> cost`` applied to element pairs.

    Returns:
        The bottom-right cell of the table. This is ``0`` when both
        sequences are empty and ``inf`` when exactly one of them is empty.
    """
    len1 = len(v1)
    len2 = len(v2)

    # Row 0 / column 0 act as an infinite-cost border; only the origin is free.
    acc = np.full((len1 + 1, len2 + 1), np.inf)
    acc[0, 0] = 0

    for r in range(len1):
        for c in range(len2):
            step_cost = dist_fn(v1[r], v2[c])
            # Neighbours of table cell (r + 1, c + 1): left, up, diagonal.
            left = acc[r + 1, c]
            up = acc[r, c + 1]
            diagonal = acc[r, c]
            acc[r + 1, c + 1] = step_cost + np.min([left, up, diagonal])

    return acc[len1, len2]
from audio_classification.classifier.base_classifier import BaseClassifier
import numpy as np
from dtw import accelerated_dtw
import pandas as pd
def normalize_matrices(A, B):
    """Crop two 2-D arrays to a common number of columns.

    Both inputs are truncated along axis 1 to the smaller of their two
    column counts; the rows are left untouched.

    Returns:
        A tuple ``(A_cropped, B_cropped)``.
    """
    common_cols = min(A.shape[1], B.shape[1])
    keep = slice(None, common_cols)
    cropped_a = A[:, keep]
    cropped_b = B[:, keep]
    return cropped_a, cropped_b
def compare_mfccs(A, B, dist):
    """Return the DTW result for two MFCC matrices.

    The matrices are first cropped to the same number of columns with
    ``normalize_matrices`` and then passed, together with ``dist``, to
    ``accelerated_dtw``. Only the first of the four values it returns is
    handed back (presumably the accumulated DTW distance -- confirm against
    the ``dtw`` package documentation).
    """
    cropped_a, cropped_b = normalize_matrices(A, B)
    dtw_result = accelerated_dtw(cropped_a, cropped_b, dist)
    (distance, _, _, _) = dtw_result
    return distance
def aggregate_genre(grpdf, testdf, agg_func, n_compared_songs, dist):
    """Score one group of training rows against a test MFCC matrix.

    Args:
        grpdf: DataFrame for a single group; its ``'file'`` column holds
            paths that are loaded with ``np.load``.
        testdf: The already-loaded test matrix (despite the name it is
            passed straight to ``compare_mfccs``).
        agg_func: Callable ``(distances, rows) -> score`` that reduces the
            per-file results to a single value.
        n_compared_songs: How many leading rows of ``grpdf`` to compare.
        dist: Distance specification forwarded to ``compare_mfccs``.

    Returns:
        Whatever ``agg_func`` returns for the compared rows.
    """
    sample = grpdf.head(n_compared_songs)

    def distance_to_test(path):
        return compare_mfccs(np.load(path), testdf, dist)

    distances = sample['file'].apply(distance_to_test)
    return agg_func(distances, sample)
class DTWClassifier(BaseClassifier):
    """Classifier that picks the label whose training files are DTW-closest.

    For every label in the training frame, the first ``n_compared_songs``
    files are compared with the test sample via ``aggregate_genre``; the
    per-file results are reduced according to ``mode`` and the label with
    the smallest reduced value wins.
    """

    # Supported aggregation modes.
    MODE_MEAN = 'MEAN'
    MODE_MIN = 'MIN'

    def __init__(self, traindf, mode=MODE_MEAN, n_compared_songs=10, dist='euclidean'):
        """Store the training data and comparison settings.

        Args:
            traindf: Training DataFrame; it is grouped by its ``'label'``
                column and its ``'file'`` column is read by
                ``aggregate_genre``.
            mode: ``MODE_MEAN`` or ``MODE_MIN``; validated lazily when a
                sample is classified.
            n_compared_songs: Number of leading rows per label to compare.
            dist: Distance specification forwarded to ``aggregate_genre``.
        """
        # NOTE(review): BaseClassifier.__init__ is not invoked here -- confirm
        # the base class needs no initialisation.
        self.traindf = traindf
        self.mode = mode
        self.n_compared_songs = n_compared_songs
        self.dist = dist

    def __get_mfcc_agg_func(self):
        """Return the reducer ``(mfccs, grouped_df) -> score`` for ``self.mode``.

        Raises:
            ValueError: If ``self.mode`` is not one of the ``MODE_*`` constants.
        """
        if self.mode == self.MODE_MEAN:
            return lambda mfccs, grouped_df: sum(mfccs) / len(grouped_df)
        if self.mode == self.MODE_MIN:
            return lambda mfccs, grouped_df: min(mfccs)
        # ValueError is a subclass of Exception, so existing broad handlers
        # keep working while callers gain a precise type to catch.
        raise ValueError("Bad mode specified.")

    def classify(self, testdf):
        """Classify every sample referenced by the first column of ``testdf``.

        Each value in the first column is loaded with ``np.load`` and passed
        to ``classify_mfcc``.

        Returns:
            A single-column DataFrame with the predicted label per row.
        """
        paths = testdf.iloc[:, 0]
        result = paths.apply(lambda path: self.classify_mfcc(np.load(path)))
        return pd.DataFrame(result)

    def classify_mfcc(self, mfcc):
        """Return the label with the smallest aggregated score for ``mfcc``."""
        # Resolve (and validate) the reducer once instead of once per label.
        agg_func = self.__get_mfcc_agg_func()
        result = self.traindf.groupby('label').apply(
            lambda group: aggregate_genre(group, mfcc, agg_func, self.n_compared_songs, self.dist)
        )
        return result.idxmin()
\ No newline at end of file
......@@ -13,8 +13,8 @@ DUMB_PREFIX = '/Users/sness/mirex2008/genres/'
def preprocess_sample(path, **kwargs):
    """Load the audio at ``path`` and return ``librosa.feature.mfcc`` of it.

    Extra keyword arguments are forwarded unchanged to
    ``librosa.feature.mfcc``.
    """
    audio_file, sr = librosa.load(path)
    y, _ = librosa.effects.trim(audio_file)
    return librosa.feature.mfcc(y=y, sr=sr, **kwargs)
    # NOTE(review): the statements below are unreachable as written. This
    # looks like the removed and added sides of a diff hunk shown together:
    # the two lines above would be the old version (trimmed signal) and the
    # two lines below the new version (untrimmed signal) -- confirm which
    # variant is intended and delete the other.
    # y, _ = librosa.effects.trim(audio_file)
    return librosa.feature.mfcc(y=audio_file, sr=sr, **kwargs)
def preprocess(n_mfcc):
......@@ -61,4 +61,4 @@ def preprocess(n_mfcc):
if __name__ == '__main__':
    # Script entry point: run the preprocessing step with the given n_mfcc.
    # NOTE(review): both calls appear here, but the surrounding hunk header
    # suggests preprocess(20) was replaced by preprocess(13) -- confirm that
    # only one call is meant to remain.
    preprocess(20)
    preprocess(13)
absl-py==0.11.0
appdirs==1.4.4
appnope==0.1.0
argon2-cffi==20.1.0
astunparse==1.6.3
async-generator==1.10
attrs==20.2.0
-e git+https://gitlab.fit.cvut.cz/rudolja4/ni-vmm-music-genre-classification.git@5208df95d488bb4a7806a242606bbef486bd3b9e#egg=audio_classification
audioread==2.1.9
backcall==0.2.0
bleach==3.2.1
cached-property==1.5.2
cachetools==4.1.1
certifi==2020.6.20
cffi==1.14.3
chardet==3.0.4
......@@ -13,8 +18,15 @@ click==7.1.2
cycler==0.10.0
decorator==4.4.2
defusedxml==0.6.0
dtw==1.4.0
entrypoints==0.3
Flask==1.1.2
gast==0.3.3
google-auth==1.23.0
google-auth-oauthlib==0.4.2
google-pasta==0.2.0
grpcio==1.33.2
h5py==2.10.0
idna==2.10
importlib-metadata==2.0.0
ipykernel==5.3.4
......@@ -28,9 +40,12 @@ jsonschema==3.2.0
jupyter-client==6.1.7
jupyter-core==4.6.3
jupyterlab-pygments==0.1.2
Keras==2.4.3
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
librosa==0.8.0
llvmlite==0.34.0
Markdown==3.3.3
MarkupSafe==1.1.1
matplotlib==3.3.2
mistune==0.8.4
......@@ -40,7 +55,9 @@ nbformat==5.0.8
nest-asyncio==1.4.2
notebook==6.1.4
numba==0.51.2
numpy==1.19.4
numpy==1.18.5
oauthlib==3.1.0
opt-einsum==3.3.0
packaging==20.4
pandas==1.1.4
pandocfilters==1.4.3
......@@ -51,21 +68,33 @@ Pillow==8.0.1
pooch==1.2.0
prometheus-client==0.8.0
prompt-toolkit==3.0.8
protobuf==3.14.0
ptyprocess==0.6.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
Pygments==2.7.2
pyparsing==2.4.7
pyrsistent==0.17.3
python-dateutil==2.8.1
pytz==2020.4
PyYAML==5.3.1
pyzmq==19.0.2
requests==2.24.0
requests-oauthlib==1.3.0
resampy==0.2.2
rsa==4.6
scikit-learn==0.23.2
scipy==1.5.3
seaborn==0.11.0
Send2Trash==1.5.0
six==1.15.0
SoundFile==0.10.3.post1
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-estimator==2.3.0
termcolor==1.1.0
terminado==0.9.1
testpath==0.4.4
threadpoolctl==2.1.0
......@@ -75,4 +104,5 @@ urllib3==1.25.11
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==1.0.1
wrapt==1.12.1
zipp==3.4.0
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment