Commit 355cf8f8 authored by Tomas Rokos's avatar Tomas Rokos
Browse files

Fix the new preprocessing for mfccs

parent 4245c7ce
from audio_classification.classifier.base_classifier import BaseClassifier
import numpy as np
from dtw import accelerated_dtw
import pandas as pd
from audio_classification.dtwclassifier.distances import euclid
from audio_classification.dtwclassifier.dtw import dtw
from audio_classification.dtwclassifier.ranges import default_range
def normalize_matrices(A, B):
norm_shape = min(A.shape[1], B.shape[1])
return A[:, :norm_shape], B[:, :norm_shape]
def compare_mfccs(A, B, dist):
A, B = normalize_matrices(A, B)
res, _, _, _ = accelerated_dtw(A, B, dist)
return res
def aggregate_genre(grpdf, testdf, agg_func, n_compared_songs, dist):
first_ten = grpdf.head(n_compared_songs)
mfccs = first_ten['file'].apply(lambda x: compare_mfccs(np.load(x), testdf, dist))
return agg_func(mfccs, first_ten)
class DTWClassifier(BaseClassifier):
def __init__(self, mode=MODE_MEAN, n_compared_songs=10, dist='euclidean'):
def __init__(self, mode=MODE_MEAN, n_compared_songs=10, dist_fn=euclid, range_fn=default_range):
self.mode = mode
self.n_compared_songs = n_compared_songs
self.dist = dist
self.dist_fn = dist_fn
self.range_fn = range_fn
self.traindf = None
def __get_mfcc_agg_func(self):
......@@ -38,6 +30,16 @@ class DTWClassifier(BaseClassifier):
return lambda mfccs, grouped_df: min(mfccs)
raise Exception("Bad mode specified.")
def __compare_mfccs(self, A, B):
A, B = normalize_matrices(A, B)
res, _ = dtw(A, B, self.dist_fn, self.range_fn)
return res
def __aggregate_genre(self, grpdf, testdf, agg_func, n_compared_songs):
first_ten = grpdf.head(n_compared_songs)
mfccs = first_ten['file'].apply(lambda x: self.__compare_mfccs(np.load(x), testdf))
return agg_func(mfccs, first_ten)
def classify(self, mfcc):
return self.__classify_mfcc(mfcc)
......@@ -47,7 +49,7 @@ class DTWClassifier(BaseClassifier):
def __classify_mfcc(self, mfcc):
result = self.traindf.groupby('label').apply(
lambda x: aggregate_genre(x, mfcc, self.__get_mfcc_agg_func(), self.n_compared_songs, self.dist)
lambda x: self.__aggregate_genre(x, mfcc, self.__get_mfcc_agg_func(), self.n_compared_songs)
return result.idxmin()
......@@ -59,6 +61,6 @@ class DTWClassifierFactory:
def create():
cls = DTWClassifier()
df = pd.read_csv('preprocessed_nmfcc_13/data.csv', header=None, names=['file', 'label'])
df = pd.read_csv('preprocessed_nmfcc_13/data.csv')['file'], df['label'])
return cls
......@@ -79,4 +79,3 @@ def classify():
if __name__ == '__main__':
......@@ -2,7 +2,6 @@ import base64
import io
import librosa
import librosa.display
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment