diff --git a/Ch14_Computer_Vision/.idea/.gitignore b/Ch14_Computer_Vision/.idea/.gitignore new file mode 100644 index 00000000..35410cac --- /dev/null +++ b/Ch14_Computer_Vision/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/Ch14_Computer_Vision/.idea/Ch14_Computer_Vision.iml b/Ch14_Computer_Vision/.idea/Ch14_Computer_Vision.iml new file mode 100644 index 00000000..f571432d --- /dev/null +++ b/Ch14_Computer_Vision/.idea/Ch14_Computer_Vision.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/Ch14_Computer_Vision/.idea/deployment.xml b/Ch14_Computer_Vision/.idea/deployment.xml new file mode 100644 index 00000000..aba64b06 --- /dev/null +++ b/Ch14_Computer_Vision/.idea/deployment.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + \ No newline at end of file diff --git a/Ch14_Computer_Vision/.idea/inspectionProfiles/profiles_settings.xml b/Ch14_Computer_Vision/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 00000000..105ce2da --- /dev/null +++ b/Ch14_Computer_Vision/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/Ch14_Computer_Vision/.idea/misc.xml b/Ch14_Computer_Vision/.idea/misc.xml new file mode 100644 index 00000000..db8786c0 --- /dev/null +++ b/Ch14_Computer_Vision/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/Ch14_Computer_Vision/.idea/modules.xml b/Ch14_Computer_Vision/.idea/modules.xml new file mode 100644 index 00000000..9b389ad5 --- /dev/null +++ b/Ch14_Computer_Vision/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/Ch14_Computer_Vision/.idea/vcs.xml b/Ch14_Computer_Vision/.idea/vcs.xml new file mode 100644 index 00000000..6c0b8635 --- /dev/null +++ b/Ch14_Computer_Vision/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb b/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb index e5683066..ae0b4882 100644 --- a/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb +++ b/Ch14_Computer_Vision/Single_Shot_Multibox_Detection.ipynb @@ -640,12 +640,25 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "id_cat = dict()\n", - "id_cat[0] = 'pikachu'\n", + "id_cat[0] = 'pikachu'" + ] + }, + { + "metadata": { + "jupyter": { + "is_executing": true + } + }, + "cell_type": "code", + "source": [ + "\n", + "import torch\n", + "import torch.nn as nn\n", "\n", "class FocalLoss(nn.Module):\n", " def __init__(self, alpha=0.25, gamma=2, device=\"cuda:0\", eps=1e-10):\n", @@ -658,19 +671,29 @@ " def forward(self, input, target):\n", " p = torch.sigmoid(input)\n", " pt = p * target.float() + (1.0 - p) * (1 - target).float()\n", - " alpha_t = (1.0 - self.alpha) * target.float() + self.alpha * (1 - target).float()\n", - " loss = - 1.0 * torch.pow((1 - pt), self.gamma) * torch.log(pt + self.eps)\n", - " return loss.sum()\n", - " \n", + " alpha_t = self.alpha * target.float() + (1.0 - self.alpha) * (1 - target).float()\n", + " loss = - alpha_t * torch.pow((1 - pt), self.gamma) * torch.log(pt + self.eps)\n", + " return loss.sum()\n" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": 21, + "source": [ + "\n", "class SSDLoss(nn.Module):\n", " def __init__(self, loc_factor, jaccard_overlap, device = \"cuda:0\", **kwargs):\n", " super().__init__()\n", " self.fl = FocalLoss(**kwargs)\n", " self.loc_factor = loc_factor\n", " self.jaccard_overlap = jaccard_overlap\n", - " \n", "\n", - " \n", + "\n", + "\n", " self.device = device\n", "\n", " def one_hot_encoding(labels, num_classes):\n", @@ -684,23 +707,23 @@ " torch.log((x[:, 3:4] / anchors[overlap_indicies, 3:4]))\n", " ], dim=1)\n", "\n", - " def forward(self, class_hat, bb_hat, class_true, bb_true, anchors): \n", + " def forward(self, class_hat, bb_hat, class_true, bb_true, anchors):\n", " loc_loss = 0.0\n", " class_loss = 0.0\n", - " \n", - " \n", + "\n", + "\n", " for i in range(len(class_true)): # Batch level\n", "\n", " class_hat_i = class_hat[i, :, :]\n", "\n", " bb_true_i = bb_true[i].float()\n", - " \n", - " class_true_i = class_true[i] \n", - " \n", + "\n", + " class_true_i = class_true[i]\n", + "\n", " class_target = torch.zeros(class_hat_i.shape[0]).long().to(self.device)\n", - " \n", + "\n", " overlap_list = d2l.find_overlap(bb_true_i.squeeze(0), anchors, self.jaccard_overlap)\n", - " \n", + "\n", " temp_loc_loss = 0.0\n", " for j in range(len(overlap_list)): # BB level\n", " overlap = overlap_list[j]\n",