\n",
"Micro-averaging: In micro-averaging, we aggregate the binary counts over all classes. We sum the true positives, false positives, false negatives, and true negatives of every class, then compute the performance metrics from these pooled counts. Because every sample contributes equally, frequent classes dominate the result.\n",
"\n",
"Macro-averaging: In macro-averaging, we compute the binary metric (e.g., precision) for each class separately and then take the unweighted mean across all classes. Macro-averaging gives equal weight to every class regardless of its size. It is therefore suitable when performance on rare classes matters as much as performance on common ones.\n",
"\n",
"Weighted averaging: In weighted averaging, we compute the binary metric for each class separately and then take a weighted mean across all classes. The weight of each class is proportional to its number of samples in the dataset. Weighted averaging is the most suitable averaging technique when the classes in the dataset are imbalanced.\n",
"\n",
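"To make the three schemes concrete, here is a sketch (a hypothetical illustration, not part of the assignment) that averages per-class precision computed from raw counts. The helper `averaged_precision` and the counts are made up for this example. The same pattern applies to recall or fall-out.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def averaged_precision(tp, fp, support):\n",
"    # hypothetical helper: tp, fp, support hold one entry per class\n",
"    # assumes every class has at least one positive prediction (tp + fp > 0)\n",
"    tp, fp, support = (np.asarray(a, dtype=float) for a in (tp, fp, support))\n",
"    per_class = tp / (tp + fp)                 # binary precision of each class\n",
"    micro = tp.sum() / (tp.sum() + fp.sum())   # pool the counts, then divide\n",
"    macro = per_class.mean()                   # every class counts the same\n",
"    weighted = (support * per_class).sum() / support.sum()  # weight by class size\n",
"    return micro, macro, weighted\n",
"\n",
"# hypothetical counts: one large class and two small, poorly predicted ones\n",
"micro, macro, weighted = averaged_precision(tp=[80, 6, 2], fp=[5, 4, 8], support=[100, 10, 10])\n",
"print(f'micro={micro:.3f} macro={macro:.3f} weighted={weighted:.3f}')\n",
"# micro=0.838 macro=0.580 weighted=0.851\n",
"```\n",
"\n",
"The two small classes pull the macro average well below the other two averages, while the micro and weighted averages are dominated by the large class.\n",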
"\n",
"\n",
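"We can use the function below to compute multiclass performance measures, aggregating the per-class binary metrics with weighted averaging, where each class is weighted by its number of samples in `ground_truth`. The function returns both the aggregated metrics and the list of per-class metrics."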
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"deletable": false,
"editable": false,
"nbgrader": {
"cell_type": "code",
"checksum": "e19e1aa94b31762ba50c300b1066f333",
"grade": false,
"grade_id": "cell-d95e6fcea9692d4d",
"locked": true,
"schema_version": 3,
"solution": false,
"task": false
}
},
"outputs": [],
"source": [
"##### DO NOT CHANGE #####\n",
"\n",
"def compute_multiclass_metrics(knn, ground_truth, threshold=.3, labels=np.arange(10)):\n",
" metrics_list = []\n",
" for label in labels:\n",
" metrics_list.append(compute_binary_metrics(knn, label, ground_truth, threshold))\n",
" weights = np.asarray([(ground_truth==label).sum() for label in labels])\n",
" multiclass_metrics = {}\n",
" for m in ['precision', 'recall', 'fall-out']:\n",
" multiclass_metrics[m] = sum([weights[i]*metrics_list[i][m] for i in range(len(labels))])/weights.sum()\n",
" return multiclass_metrics, metrics_list\n",
"\n",
"##### DO NOT CHANGE #####"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Multiclass knn performance, weighted average: \n",
"\t-precision: 0.84\n",
"\t-recall: 0.64\n",
"\t-fall-out: 0.02\n"
]
}
],
"source": [
"knn.set_k(5)\n",
"metrics, _ = compute_multiclass_metrics(knn, y_unknown, threshold=.3)\n",
"print(f'Multiclass knn performance, weighted average: ')\n",
"for k, v in metrics.items():\n",
" print(f'\\t-{k}: {v:.2f}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As stated above, the ROC curve is generated by computing the TPR and FPR for different threshold values between 0 and 1. As the threshold increases, the model becomes more conservative and classifies fewer instances as positive, so both the TPR and the FPR decrease (or stay the same). Conversely, as the threshold decreases, the model becomes more aggressive and classifies more instances as positive, so both rates increase (or stay the same).\n",
"\n",
"A perfect classifier's ROC curve passes through the top-left corner of the plot, where TPR = 1 and FPR = 0. The area under the ROC curve (AUC) summarizes the classifier's performance across all thresholds. A value of 1 indicates a perfect classifier, while a value of 0.5 indicates a classifier that is no better than random guessing.\n",
"\n",
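"The threshold sweep described above can be sketched for a single binary problem as follows. This is a hypothetical illustration, independent of the notebook's `knn` object. The names `roc_points` and `trapezoid_auc` are made up, and the scores stand in for a classifier's predicted probability of the positive class. The trapezoid rule is written out explicitly instead of calling `np.trapz`, which NumPy 2.0 deprecated in favor of `np.trapezoid`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def roc_points(scores, is_positive, thresholds):\n",
"    # hypothetical helper: returns one (FPR, TPR) pair per threshold\n",
"    scores = np.asarray(scores, dtype=float)\n",
"    is_positive = np.asarray(is_positive, dtype=bool)\n",
"    fpr, tpr = [], []\n",
"    for t in thresholds:\n",
"        predicted = scores >= t  # a higher threshold never predicts more positives\n",
"        tpr.append((predicted & is_positive).sum() / is_positive.sum())\n",
"        fpr.append((predicted & ~is_positive).sum() / (~is_positive).sum())\n",
"    return np.array(fpr), np.array(tpr)\n",
"\n",
"def trapezoid_auc(fpr, tpr):\n",
"    # sort by FPR (ties broken by TPR) and sum the trapezoids from left to right\n",
"    order = np.lexsort((tpr, fpr))\n",
"    x, y = fpr[order], tpr[order]\n",
"    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2))\n",
"\n",
"# hypothetical scores for six samples: the first three are positive\n",
"scores = [0.95, 0.85, 0.45, 0.65, 0.35, 0.15]\n",
"labels = [True, True, True, False, False, False]\n",
"fpr, tpr = roc_points(scores, labels, thresholds=np.linspace(0, 1, 11))\n",
"print(round(trapezoid_auc(fpr, tpr), 3))  # 0.889\n",
"```\n",
"\n",
"Here 8 of the 9 positive-negative pairs are ranked correctly. Only the positive scored 0.45 falls below the negative scored 0.65, so the AUC is 8/9.\n",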
"The following code plots the recall vs. precision (RvP) curve and the ROC curve (with its AUC), so that we can see how changing the threshold yields classifiers with different trade-offs between the metrics. In the ROC plot, recall plays the role of the TPR and fall-out the role of the FPR."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"deletable": false,
"editable": false,
"nbgrader": {
"cell_type": "code",
"checksum": "03e3f6d3713e1ec9a7b7ecff034d6ab1",
"grade": false,
"grade_id": "cell-20702f8b17c1bd95",
"locked": true,
"schema_version": 3,
"solution": false,
"task": false
}
},
"outputs": [],
"source": [
"##### DO NOT CHANGE #####\n",
"\n",
"def plot_measures(knn, y_unknown, plot_type='ROC', all=True):\n",
" measures = {'ROC':('fall-out', 'recall'), 'RvP': ('recall', 'precision')}\n",
" measures = measures[plot_type]\n",
" m0 = {k:[] for k in ['mc']+list(range(10))}\n",
" m1 = {k:[] for k in ['mc']+list(range(10))}\n",
" for threshold in np.arange(0,1,.1):\n",
" mc, bin = compute_multiclass_metrics(\n",
" knn, \n",
" y_unknown, \n",
" threshold)\n",
" m0['mc'].append(mc[measures[0]])\n",
" m1['mc'].append(mc[measures[1]])\n",
" for i in range(10):\n",
" m0[i].append(bin[i][measures[0]])\n",
" m1[i].append(bin[i][measures[1]])\n",
" if all:\n",
" for i in range(10):\n",
" plt.plot(m0[i][::-1], m1[i][::-1], label='digit_'+str(i))\n",
" plt.plot(m0['mc'][::-1], m1['mc'][::-1], label='multiclass')\n",
" auc = ''\n",
" if plot_type=='ROC':\n",
" plt.plot([0, 1], [0, 1], linestyle='dashed', color='red', label='random_guess')\n",
" auc = np.abs(np.trapz(y=m1['mc'][::-1], x=m0['mc'][::-1]))\n",
" auc = f', AUC: {auc:.3f}'\n",
" plt.xlim(0, 1)\n",
" plt.ylim(0, 1)\n",
" plt.xlabel(measures[0])\n",
" plt.ylabel(measures[1])\n",
" plt.title(plot_type+auc)\n",
" plt.legend()\n",
" plt.grid()\n",
" plt.show()\n",
" plt.close()\n",
"\n",
"##### DO NOT CHANGE #####"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": false,
"editable": false,
"nbgrader": {
"cell_type": "markdown",
"checksum": "8c7d14c3da4a8ef2a65620bd6c08a34d",
"grade": false,
"grade_id": "cell-ba8f5378c673ceef",
"locked": true,
"schema_version": 3,
"solution": false,
"task": false
}
},
"source": [
"Let's plot some of the curves varying the number of neighbors we use for the classification."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"K: 2\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"