similarity_nn_training.ipynb 123 KB
Newer Older
Jan Rudolf's avatar
Jan Rudolf committed
1 2 3 4
{
 "cells": [
  {
   "cell_type": "code",
5
   "execution_count": 24,
Jan Rudolf's avatar
Jan Rudolf committed
6 7 8 9 10 11 12
   "metadata": {},
   "outputs": [],
   "source": [
    "from audio_classification.preprocess import preprocess\n",
    "\n",
    "from sklearn import preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
13 14 15
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "import os\n",
Jan Rudolf's avatar
Jan Rudolf committed
16 17 18 19 20 21 22 23 24 25
    "import umap\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import Model\n",
    "from tensorflow.keras.utils import plot_model\n",
26
    "from tensorflow.keras.layers import Input, ZeroPadding2D, Conv2D, BatchNormalization, Activation, Dense, add, MaxPool2D, Dropout, GlobalMaxPool2D"
Jan Rudolf's avatar
Jan Rudolf committed
27 28 29 30 31 32 33 34 35 36
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MFCC coefficients requested from preprocess(); first axis of every sample (see SAMPLE_SHAPE)\n",
    "NUMBER_OF_MFCCS = 13\n",
    "# Frames kept per clip: every feature array is truncated to this many columns (axis 1)\n",
    "CLIP_SIZE = 1290\n",
37 38 39 40 41 42 43 44 45
    "# Per-sample input shape of the embedding network; the trailing 1 is a single channel.\n",
    "SAMPLE_SHAPE = (NUMBER_OF_MFCCS, CLIP_SIZE, 1)\n",
    "\n",
    "# Artifacts written by the training callback: the best embedding model plus the\n",
    "# training embeddings and labels that back the k-NN classifier.\n",
    "MODEL_FOLDER = os.path.join('models', 'siamese')\n",
    "MODEL_PATH = os.path.join(MODEL_FOLDER, 'model.h5')\n",
    "KNN_SAMPLES_PATH = os.path.join(MODEL_FOLDER, 'data.npy')\n",
    "KNN_LABELS_PATH = os.path.join(MODEL_FOLDER, 'labels.npy')\n",
    "\n",
    "# exist_ok=True replaces the racy isdir()-then-makedirs() check.\n",
    "os.makedirs(MODEL_FOLDER, exist_ok=True)"
Jan Rudolf's avatar
Jan Rudolf committed
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cached features found\n"
     ]
    }
   ],
   "source": [
62 63
    "# preprocess() is project code (audio_classification). Judging by its use below, `data` is\n",
    "# presumably a pandas DataFrame with 'file' and 'label' columns. `similarity` is not used\n",
    "# in the cells visible here.\n",
    "data, similarity = preprocess(NUMBER_OF_MFCCS)\n",
    "\n",
Jan Rudolf's avatar
Jan Rudolf committed
64 65 66 67 68 69 70 71 72 73
    "le = preprocessing.LabelEncoder()\n",
    "# label strings -> integer class ids 0..n_classes-1 (le.classes_ maps them back)\n",
    "transformed = le.fit_transform(data['label'])\n",
    "\n",
    "# Keep the first CLIP_SIZE frames (axis 1) of every cached feature array, in row order.\n",
    "clips = [np.load(path)[:, :CLIP_SIZE] for path in data['file']]\n",
    "# np.stack needs identical shapes; report short clips by file instead of failing opaquely.\n",
    "too_short = [path for path, clip in zip(data['file'], clips) if clip.shape[1] < CLIP_SIZE]\n",
    "if too_short:\n",
    "    raise ValueError(f'{len(too_short)} clip(s) have fewer than {CLIP_SIZE} frames, e.g. {too_short[0]}')\n",
    "\n",
    "X = np.expand_dims(np.stack(clips), axis=3)  # append channel axis -> (n_clips, rows, CLIP_SIZE, 1)\n",
    "y = np.array(transformed)\n",
    "\n",
74 75
    "# Hold out 20% as test, then split the remaining 80% again with test_size=0.25\n",
    "# (sklearn's default), i.e. roughly 60/20/20 train/valid/test. Both splits are\n",
    "# stratified by label and seeded for reproducibility.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.2, stratify=y, random_state=666\n",
    ")\n",
    "X_train, X_valid, y_train, y_valid = train_test_split(\n",
    "    X_train, y_train, test_size=0.25, stratify=y_train, random_state=666\n",
    ")\n",
Jan Rudolf's avatar
Jan Rudolf committed
76 77
    "\n",
    "# Batched but NOT shuffled. This must stay that way while CustomMetricCallback and the\n",
    "# evaluation cell pair model.predict(train_dataset) row-by-row with y_train.\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)\n",
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
    "# Evaluation pipelines: same batch size as training, order preserved.\n",
    "valid_dataset, test_dataset = (\n",
    "    tf.data.Dataset.from_tensor_slices(split).batch(32)\n",
    "    for split in ((X_valid, y_valid), (X_test, y_test))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(MODEL_FOLDER, exist_ok=True)\n",
    "\n",
    "\n",
    "class CustomMetricCallback(keras.callbacks.Callback):\n",
    "    '''Checkpointing / early-stopping callback driven by k-NN accuracy on the embeddings.\n",
    "\n",
    "    After every epoch the model embeds the global train_dataset and valid_dataset. A\n",
    "    KNeighborsClassifier is fit on (train embeddings, y_train), and its accuracy on\n",
    "    (valid embeddings, y_valid) is the monitored score.\n",
    "\n",
    "    Args:\n",
    "        patience: stop once this many consecutive epochs bring no new best validation\n",
    "            score; None disables early stopping.\n",
    "        verbose: if > 0, print train/validation k-NN accuracy every epoch.\n",
    "        save_best: on improvement, save the model to MODEL_PATH and the training\n",
    "            embeddings/labels to KNN_SAMPLES_PATH / KNN_LABELS_PATH.\n",
    "        n_neighbors: k of the k-NN classifier (default 12).\n",
    "    '''\n",
    "    def __init__(self, patience, verbose = 0, save_best = True, n_neighbors = 12):\n",
    "        # Plain super() so that Callback.__init__ itself runs and sets up its attributes.\n",
    "        super().__init__()\n",
    "        self.patience = patience\n",
    "        self.verbose = verbose\n",
    "        self.save_best = save_best\n",
    "        self.n_neighbors = n_neighbors\n",
    "        self.best_val_score = -float('inf')\n",
    "        self.last_improvement = 0\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        train_features = np.asarray(self.model.predict(train_dataset))\n",
    "        valid_features = np.asarray(self.model.predict(valid_dataset))\n",
    "\n",
    "        knn = KNeighborsClassifier(n_neighbors=self.n_neighbors)\n",
    "        knn.fit(train_features, y_train)\n",
    "\n",
    "        # Train score is optimistic: each training embedding is among its own neighbours.\n",
    "        train_score = accuracy_score(y_train, knn.predict(train_features))\n",
    "        valid_score = accuracy_score(y_valid, knn.predict(valid_features))\n",
    "\n",
    "        if self.verbose > 0:\n",
    "            print(f', acc: {train_score}, val_acc: {valid_score}')\n",
    "\n",
    "        # Improvement is tracked even when save_best=False, so patience still works.\n",
    "        if valid_score > self.best_val_score:\n",
    "            if self.save_best:\n",
    "                print(f'Best improved: {self.best_val_score} -> {valid_score}. Saving model to {MODEL_FOLDER}')\n",
    "                self.model.save(MODEL_PATH, overwrite=True)\n",
    "                np.save(KNN_SAMPLES_PATH, train_features)\n",
    "                np.save(KNN_LABELS_PATH, y_train)\n",
    "            else:\n",
    "                print(f'Best improved: {self.best_val_score} -> {valid_score}')\n",
    "            self.best_val_score = valid_score\n",
    "            self.last_improvement = epoch\n",
    "            return\n",
    "\n",
    "        print('Score did not improve')\n",
    "        # `>=`: stop after exactly `patience` consecutive non-improving epochs.\n",
    "        if self.patience is not None and epoch - self.last_improvement >= self.patience:\n",
    "            print(f'Score did not improve for {self.patience} epochs. Stopping')\n",
    "            self.model.stop_training = True"
Jan Rudolf's avatar
Jan Rudolf committed
131 132 133 134 135 136 137 138
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
139
   "outputs": [],
Jan Rudolf's avatar
Jan Rudolf committed
140 141 142 143 144
   "source": [
    "def res_block_2l(block_number, input_block, filters, strides = 1):\n",
    "    '''Two-conv residual block: conv3x3-BN-ReLU -> conv3x3-BN, plus shortcut, then ReLU.\n",
    "\n",
    "    Args:\n",
    "        block_number: only used to build unique layer names ('block<N>_...').\n",
    "        input_block: input tensor; its channel count is read from the last axis\n",
    "            (assumes the default channels_last data format).\n",
    "        filters: number of output channels of the block.\n",
    "        strides: stride of the first conv and of the projection shortcut.\n",
    "\n",
    "    Returns:\n",
    "        The block's output tensor with `filters` channels.\n",
    "    '''\n",
    "    block_name = 'block' + str(block_number) + '_'\n",
    "\n",
    "    y = Conv2D(\n",
    "        filters=filters,\n",
    "        kernel_size=3,\n",
    "        strides=strides,\n",
    "        padding='same',\n",
    "        name=block_name + 'conv_1'\n",
    "    )(input_block)\n",
    "    y = BatchNormalization(name=block_name + 'bn_1')(y)\n",
    "    y = Activation('relu', name=block_name + 'activation_1')(y)\n",
    "\n",
    "    y = Conv2D(\n",
    "        filters=filters,\n",
    "        kernel_size=3,\n",
    "        padding='same',\n",
    "        name=block_name + 'conv_2'\n",
    "    )(y)\n",
    "    y = BatchNormalization(name=block_name + 'bn_2')(y)\n",
    "\n",
    "    # Project the shortcut with a 1x1 conv whenever its shape differs from y's: on\n",
    "    # downsampling, and also when the channel count changes (add() would fail otherwise).\n",
    "    if strides != 1 or K.int_shape(input_block)[-1] != filters:\n",
    "        z = Conv2D(kernel_size=1, filters=filters, strides=strides, name=block_name + 'conv_S')(input_block)\n",
    "        z = BatchNormalization(name=block_name + 'bn_S')(z)\n",
    "    else:\n",
    "        z = input_block\n",
    "    x = add([z, y], name=block_name + 'add')\n",
    "    return Activation('relu', name=block_name + 'activation_after')(x)\n",
    "\n",
    "\n",
    "inp = Input(shape=SAMPLE_SHAPE, name='embedding_input')\n",
    "\n",
    "# Stem: 3x3 conv to 32 channels, BN, ReLU.\n",
    "stem = Conv2D(filters=32, kernel_size=3, padding='same', name='conv_first')(inp)\n",
    "stem = BatchNormalization(name='bn_first')(stem)\n",
    "x = Activation('relu')(stem)\n",
    "\n",
    "# Residual trunk as (filters, strides) per block; blocks are numbered from 1.\n",
    "for block_idx, (n_filters, stride) in enumerate(\n",
    "    [(32, 1), (64, 2), (64, 1), (64, 1), (128, 2), (128, 1)], start=1\n",
    "):\n",
    "    x = res_block_2l(block_idx, x, n_filters, stride)\n",
    "\n",
    "# Head: global max pool, then a 128-d embedding (dense_2 has no activation).\n",
    "# NOTE(review): the tfa TripletSemiHardLoss docs say embeddings should be\n",
    "# L2-normalized; dense_2's output is not normalized here. Consider adding it.\n",
    "x = GlobalMaxPool2D()(x)\n",
    "x = Dense(128, name='dense_1', activation='relu')(x)\n",
    "x = Dropout(0.1)(x)\n",
    "x = Dense(128, name='dense_2')(x)\n",
    "\n",
    "model = Model(inputs=inp, outputs=x, name='embedding_model')\n",
    "model.compile(optimizer=tf.keras.optimizers.Nadam(), loss=tfa.losses.TripletSemiHardLoss())"
Jan Rudolf's avatar
Jan Rudolf committed
197 198 199 200 201 202
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
203
    "scrolled": true
Jan Rudolf's avatar
Jan Rudolf committed
204 205 206 207 208 209
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [

      "Epoch 1/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.3694, acc: 0.51, val_acc: 0.37\n",
      "Best improved: -inf -> 0.37. Saving model to models/siamese\n",
      "19/19 [==============================] - 16s 854ms/step - loss: 0.3694\n",
      "Epoch 2/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.2748, acc: 0.43666666666666665, val_acc: 0.275\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 776ms/step - loss: 0.2748\n",
      "Epoch 3/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.2405, acc: 0.49833333333333335, val_acc: 0.37\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 779ms/step - loss: 0.2405\n",
      "Epoch 4/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.1568, acc: 0.49833333333333335, val_acc: 0.355\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 776ms/step - loss: 0.1568\n",
      "Epoch 5/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.1772, acc: 0.59, val_acc: 0.455\n",
      "Best improved: 0.37 -> 0.455. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 791ms/step - loss: 0.1772\n",
      "Epoch 6/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.2355, acc: 0.5833333333333334, val_acc: 0.475\n",
      "Best improved: 0.455 -> 0.475. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 799ms/step - loss: 0.2355\n",
      "Epoch 7/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0972, acc: 0.6066666666666667, val_acc: 0.49\n",
      "Best improved: 0.475 -> 0.49. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 789ms/step - loss: 0.0972\n",
      "Epoch 8/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0934, acc: 0.6033333333333334, val_acc: 0.545\n",
      "Best improved: 0.49 -> 0.545. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 786ms/step - loss: 0.0934\n",
      "Epoch 9/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0860, acc: 0.63, val_acc: 0.53\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 786ms/step - loss: 0.0860\n",
      "Epoch 10/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.1013, acc: 0.6633333333333333, val_acc: 0.56\n",
      "Best improved: 0.545 -> 0.56. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 795ms/step - loss: 0.1013\n",
      "Epoch 11/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.1038, acc: 0.6283333333333333, val_acc: 0.565\n",
      "Best improved: 0.56 -> 0.565. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 787ms/step - loss: 0.1038\n",
      "Epoch 12/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0886, acc: 0.6833333333333333, val_acc: 0.56\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 795ms/step - loss: 0.0886\n",
      "Epoch 13/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0779, acc: 0.675, val_acc: 0.54\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 797ms/step - loss: 0.0779\n",
      "Epoch 14/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0745, acc: 0.6816666666666666, val_acc: 0.605\n",
      "Best improved: 0.565 -> 0.605. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 794ms/step - loss: 0.0745\n",
      "Epoch 15/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0760, acc: 0.71, val_acc: 0.61\n",
      "Best improved: 0.605 -> 0.61. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 792ms/step - loss: 0.0760\n",
      "Epoch 16/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0637, acc: 0.7166666666666667, val_acc: 0.55\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 774ms/step - loss: 0.0637\n",
      "Epoch 17/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0610, acc: 0.715, val_acc: 0.585\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 807ms/step - loss: 0.0610\n",
      "Epoch 18/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0612, acc: 0.7333333333333333, val_acc: 0.6\n",
      "Score did not improve\n",
      "19/19 [==============================] - 16s 841ms/step - loss: 0.0612\n",
      "Epoch 19/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0557, acc: 0.7183333333333334, val_acc: 0.55\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 779ms/step - loss: 0.0557\n",
      "Epoch 20/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0524, acc: 0.7, val_acc: 0.55\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 778ms/step - loss: 0.0524\n",
      "Epoch 21/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0588, acc: 0.7616666666666667, val_acc: 0.59\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 791ms/step - loss: 0.0588\n",
      "Epoch 22/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0407, acc: 0.7566666666666667, val_acc: 0.62\n",
      "Best improved: 0.61 -> 0.62. Saving model to models/siamese\n",
      "19/19 [==============================] - 15s 796ms/step - loss: 0.0407\n",
      "Epoch 23/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0537, acc: 0.7883333333333333, val_acc: 0.62\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 782ms/step - loss: 0.0537\n",
      "Epoch 24/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0501, acc: 0.805, val_acc: 0.6\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 780ms/step - loss: 0.0501\n",
      "Epoch 25/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0446, acc: 0.775, val_acc: 0.59\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 780ms/step - loss: 0.0446\n",
      "Epoch 26/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0464, acc: 0.79, val_acc: 0.585\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 781ms/step - loss: 0.0464\n",
      "Epoch 27/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0372, acc: 0.7783333333333333, val_acc: 0.56\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 783ms/step - loss: 0.0372\n",
      "Epoch 28/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0322, acc: 0.8133333333333334, val_acc: 0.59\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 784ms/step - loss: 0.0322\n",
      "Epoch 29/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0303, acc: 0.805, val_acc: 0.57\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 797ms/step - loss: 0.0303\n",
      "Epoch 30/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0328, acc: 0.8333333333333334, val_acc: 0.575\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 790ms/step - loss: 0.0328\n",
      "Epoch 31/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0285, acc: 0.8116666666666666, val_acc: 0.565\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 800ms/step - loss: 0.0285\n",
      "Epoch 32/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0262, acc: 0.8116666666666666, val_acc: 0.59\n",
      "Score did not improve\n",
      "19/19 [==============================] - 16s 845ms/step - loss: 0.0262\n",
      "Epoch 33/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0242, acc: 0.8033333333333333, val_acc: 0.63\n",
      "Best improved: 0.62 -> 0.63. Saving model to models/siamese\n",
      "19/19 [==============================] - 16s 833ms/step - loss: 0.0242\n",
      "Epoch 34/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0268, acc: 0.83, val_acc: 0.63\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 787ms/step - loss: 0.0268\n",
      "Epoch 35/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0252, acc: 0.8333333333333334, val_acc: 0.6\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 781ms/step - loss: 0.0252\n",
      "Epoch 36/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0260, acc: 0.8333333333333334, val_acc: 0.625\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 804ms/step - loss: 0.0260\n",
      "Epoch 37/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0285, acc: 0.8583333333333333, val_acc: 0.625\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 784ms/step - loss: 0.0285\n",
      "Epoch 38/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0291, acc: 0.8233333333333334, val_acc: 0.62\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 755ms/step - loss: 0.0291\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 39/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0169, acc: 0.8516666666666667, val_acc: 0.575\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 705ms/step - loss: 0.0169\n",
      "Epoch 40/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0153, acc: 0.8583333333333333, val_acc: 0.605\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 706ms/step - loss: 0.0153\n",
      "Epoch 41/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0134, acc: 0.8433333333333334, val_acc: 0.57\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 704ms/step - loss: 0.0134\n",
      "Epoch 42/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0117, acc: 0.84, val_acc: 0.6\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 703ms/step - loss: 0.0117\n",
      "Epoch 43/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0154, acc: 0.8533333333333334, val_acc: 0.595\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 705ms/step - loss: 0.0154\n",
      "Epoch 44/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0168, acc: 0.85, val_acc: 0.58\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 706ms/step - loss: 0.0168\n",
      "Epoch 45/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0200, acc: 0.8683333333333333, val_acc: 0.585\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 709ms/step - loss: 0.0200\n",
      "Epoch 46/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0146, acc: 0.87, val_acc: 0.605\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 705ms/step - loss: 0.0146\n",
      "Epoch 47/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0120, acc: 0.8883333333333333, val_acc: 0.62\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 705ms/step - loss: 0.0120\n",
      "Epoch 48/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0046, acc: 0.8733333333333333, val_acc: 0.615\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 708ms/step - loss: 0.0046\n",
      "Epoch 49/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0106, acc: 0.8883333333333333, val_acc: 0.64\n",
      "Best improved: 0.63 -> 0.64. Saving model to models/siamese\n",
      "19/19 [==============================] - 14s 717ms/step - loss: 0.0106\n",
      "Epoch 50/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0117, acc: 0.895, val_acc: 0.605\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 706ms/step - loss: 0.0117\n",
      "Epoch 51/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0111, acc: 0.89, val_acc: 0.62\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 719ms/step - loss: 0.0111\n",
      "Epoch 52/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0107, acc: 0.8783333333333333, val_acc: 0.615\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 735ms/step - loss: 0.0107\n",
      "Epoch 53/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0111, acc: 0.875, val_acc: 0.63\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 708ms/step - loss: 0.0111\n",
      "Epoch 54/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0067, acc: 0.8783333333333333, val_acc: 0.61\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 715ms/step - loss: 0.0067\n",
      "Epoch 55/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0114, acc: 0.8533333333333334, val_acc: 0.605\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 750ms/step - loss: 0.0114\n",
      "Epoch 56/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0099, acc: 0.8683333333333333, val_acc: 0.61\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 717ms/step - loss: 0.0099\n",
      "Epoch 57/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0183, acc: 0.88, val_acc: 0.61\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 737ms/step - loss: 0.0183\n",
      "Epoch 58/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0109, acc: 0.875, val_acc: 0.665\n",
      "Best improved: 0.64 -> 0.665. Saving model to models/siamese\n",
      "19/19 [==============================] - 14s 755ms/step - loss: 0.0109\n",
      "Epoch 59/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0175, acc: 0.8416666666666667, val_acc: 0.68\n",
      "Best improved: 0.665 -> 0.68. Saving model to models/siamese\n",
      "19/19 [==============================] - 14s 733ms/step - loss: 0.0175\n",
      "Epoch 60/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0175, acc: 0.8516666666666667, val_acc: 0.58\n",
      "Score did not improve\n",
      "19/19 [==============================] - 14s 713ms/step - loss: 0.0175\n",
      "Epoch 61/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0170, acc: 0.8733333333333333, val_acc: 0.625\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 708ms/step - loss: 0.0170\n",
      "Epoch 62/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0081, acc: 0.8883333333333333, val_acc: 0.595\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 707ms/step - loss: 0.0081\n",
      "Epoch 63/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0081, acc: 0.8533333333333334, val_acc: 0.64\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 705ms/step - loss: 0.0081\n",
      "Epoch 64/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0085, acc: 0.8683333333333333, val_acc: 0.59\n",
      "Score did not improve\n",
      "19/19 [==============================] - 13s 704ms/step - loss: 0.0085\n",
      "Epoch 65/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0081, acc: 0.8683333333333333, val_acc: 0.585\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 774ms/step - loss: 0.0081\n",
      "Epoch 66/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0108, acc: 0.8683333333333333, val_acc: 0.64\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 813ms/step - loss: 0.0108\n",
      "Epoch 67/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0099, acc: 0.875, val_acc: 0.61\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 779ms/step - loss: 0.0099\n",
      "Epoch 68/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0080, acc: 0.8866666666666667, val_acc: 0.61\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 777ms/step - loss: 0.0080\n",
      "Epoch 69/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0096, acc: 0.8966666666666666, val_acc: 0.595\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 777ms/step - loss: 0.0096\n",
      "Epoch 70/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0073, acc: 0.8816666666666667, val_acc: 0.585\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 783ms/step - loss: 0.0073\n",
      "Epoch 71/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0101, acc: 0.9, val_acc: 0.58\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 810ms/step - loss: 0.0101\n",
      "Epoch 72/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0083, acc: 0.89, val_acc: 0.62\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 786ms/step - loss: 0.0083\n",
      "Epoch 73/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0056, acc: 0.8733333333333333, val_acc: 0.595\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 780ms/step - loss: 0.0056\n",
      "Epoch 74/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0064, acc: 0.8916666666666667, val_acc: 0.62\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 781ms/step - loss: 0.0064\n",
      "Epoch 75/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0094, acc: 0.89, val_acc: 0.6\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 786ms/step - loss: 0.0094\n",
      "Epoch 76/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0132, acc: 0.8933333333333333, val_acc: 0.605\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 800ms/step - loss: 0.0132\n",
      "Epoch 77/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0116, acc: 0.8866666666666667, val_acc: 0.625\n",
      "Score did not improve\n",
      "19/19 [==============================] - 15s 775ms/step - loss: 0.0116\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 78/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0155, acc: 0.895, val_acc: 0.61\n",
      "Score did not improve\n",
      "19/19 [==============================] - 16s 862ms/step - loss: 0.0155\n",
      "Epoch 79/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0144, acc: 0.89, val_acc: 0.615\n",
      "Score did not improve\n",
      "19/19 [==============================] - 17s 900ms/step - loss: 0.0144\n",
      "Epoch 80/1000\n",
      "19/19 [==============================] - ETA: 0s - loss: 0.0088, acc: 0.8583333333333333, val_acc: 0.61\n",
      "Score did not improve\n",
      "Score did not improve for 20 epochs. Stopping\n",
      "19/19 [==============================] - 17s 899ms/step - loss: 0.0088\n"
Jan Rudolf's avatar
Jan Rudolf committed
543 544 545 546 547 548 549
     ]
    }
   ],
   "source": [
    "# NOTE(review): this switches every tf.function to eager mode for the rest of the\n",
    "# session, which presumably also slows later predict() calls; confirm it is still needed.\n",
    "tf.config.run_functions_eagerly(True)\n",
    "\n",
    "# No validation_data is passed: CustomMetricCallback scores validation via k-NN accuracy,\n",
    "# checkpoints the best model and stops early.\n",
    "knn_metric_callback = CustomMetricCallback(patience=20, verbose=1)\n",
    "history = model.fit(train_dataset, epochs=1000, callbacks=[knn_metric_callback])"
   ]
  },
  {
   "cell_type": "code",
557
   "execution_count": 11,
Jan Rudolf's avatar
Jan Rudolf committed
558 559 560 561 562
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the best checkpoint saved during training; compile=False because only inference is needed.\n",
    "model = tf.keras.models.load_model(MODEL_PATH, compile=False)\n",
    "\n",
    "# k-NN reference set: embeddings of the unshuffled training split, aligned with y_train.\n",
    "train_features = model.predict(train_dataset)\n",
    "knn = KNeighborsClassifier(n_neighbors=12)  # same k as CustomMetricCallback\n",
    "knn.fit(train_features, y_train)\n",
    "\n",
    "# Optimistic: each training embedding is among its own nearest neighbours.\n",
    "train_predictions = knn.predict(train_features)\n",
    "\n",
    "test_features = model.predict(test_dataset)\n",
    "test_predictions = knn.predict(test_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== TRAIN ===\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       blues       0.77      0.92      0.84        60\n",
      "   classical       0.98      1.00      0.99        60\n",
      "     country       0.77      0.85      0.81        60\n",
      "       disco       0.78      0.85      0.82        60\n",
      "      hiphop       0.77      0.77      0.77        60\n",
      "        jazz       0.92      0.95      0.93        60\n",
      "       metal       0.89      0.98      0.94        60\n",
      "         pop       0.93      0.83      0.88        60\n",
      "      reggae       0.79      0.70      0.74        60\n",
      "        rock       0.81      0.57      0.67        60\n",
      "\n",
      "    accuracy                           0.84       600\n",
      "   macro avg       0.84      0.84      0.84       600\n",
      "weighted avg       0.84      0.84      0.84       600\n",
      "\n",
      "=== TEST ===\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       blues       0.71      0.75      0.73        20\n",
      "   classical       0.94      0.85      0.89        20\n",
      "     country       0.52      0.70      0.60        20\n",
      "       disco       0.67      0.40      0.50        20\n",
      "      hiphop       0.50      0.65      0.57        20\n",
      "        jazz       0.71      0.75      0.73        20\n",
      "       metal       0.73      0.95      0.83        20\n",
      "         pop       0.75      0.75      0.75        20\n",
      "      reggae       0.47      0.40      0.43        20\n",
      "        rock       0.42      0.25      0.31        20\n",
      "\n",
      "    accuracy                           0.65       200\n",
      "   macro avg       0.64      0.65      0.63       200\n",
      "weighted avg       0.64      0.65      0.63       200\n",
      "\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay\n",
    "\n",
    "\n",
    "def show_split_report(split_name, y_true, y_pred):\n",
    "    \"\"\"Print the per-class report for one split and plot its confusion matrix.\n",
    "\n",
    "    Builds the matrix from the already-computed k-NN predictions rather than\n",
    "    re-predicting, and avoids ``plot_confusion_matrix`` (removed in scikit-learn 1.2).\n",
    "    Labels are assumed to be ``le``-encoded integers 0..n_classes-1; passing them\n",
    "    explicitly keeps matrix rows/columns aligned with ``le.classes_``.\n",
    "\n",
    "    Returns the ``ConfusionMatrixDisplay`` so the figure can be adjusted further.\n",
    "    \"\"\"\n",
    "    print(f'=== {split_name.upper()} ===')\n",
    "    print(classification_report(y_true, y_pred, target_names=le.classes_))\n",
    "    matrix = confusion_matrix(y_true, y_pred, labels=np.arange(len(le.classes_)))\n",
    "    disp = ConfusionMatrixDisplay(matrix, display_labels=le.classes_).plot(cmap=plt.cm.Blues)\n",
    "    disp.ax_.set_title(f'{split_name} confusion matrix')\n",
    "    return disp\n",
    "\n",
    "\n",
    "show_split_report('Train', y_train, train_predictions)\n",
    "show_split_report('Test', y_test, test_predictions);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project the training embeddings to 2-D for visual inspection.\n",
    "# A fixed random_state makes the layout reproducible across runs\n",
    "# (recent umap-learn versions then run single-threaded).\n",
    "UMAP_SEED = 42\n",
    "\n",
    "reducer = umap.UMAP(random_state=UMAP_SEED)\n",
    "u = reducer.fit_transform(train_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter the 2-D UMAP projection, coloured by genre.\n",
    "fig, ax = plt.subplots()\n",
    "scatter = ax.scatter(u[:, 0], u[:, 1], c=y_train, cmap='tab10')\n",
    "ax.set(title='UMAP embedding', xlabel='UMAP 1', ylabel='UMAP 2')\n",
    "\n",
    "# num=None -> exactly one handle per distinct label value, in sorted order. The default\n",
    "# 'auto' switches to a tick locator once there are more than 9 values and can drop classes,\n",
    "# which would misalign handles and names. Names come from the classes actually present.\n",
    "handles, _ = scatter.legend_elements(num=None)\n",
    "ax.legend(handles, le.inverse_transform(np.unique(y_train)));"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}