15 lines
649 B
Python
15 lines
649 B
Python
import torch.nn as nn
|
|
import torch
|
|
class DFL(nn.Module):
    """Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391

    The module turns per-bin logits into a single scalar per group by taking
    a softmax over the ``c1`` bins and then the weighted sum of the bin
    indices ``0 .. c1 - 1`` (i.e. the expectation of the bin index).  The
    weighted sum is implemented as a frozen 1x1 convolution whose weights
    are the bin indices.

    Attributes:
        conv: Bias-free ``nn.Conv2d(c1, 1, 1)`` with gradients disabled and
            weights fixed to ``arange(c1)``.
        c1: Number of bins per distribution.
    """

    def __init__(self, c1: int = 16) -> None:
        """Build the frozen 1x1 convolution holding the bin indices.

        Args:
            c1: Number of bins in each discrete distribution.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bin_index = torch.arange(c1, dtype=torch.float)
        # Copy the values straight into the existing (frozen) weight; there is
        # no need to build an intermediate nn.Parameter just to read its data.
        with torch.no_grad():
            self.conv.weight.copy_(bin_index.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the expected bin index of every distribution in ``x``.

        Args:
            x: Tensor of shape ``(batch, groups * c1, anchors)`` holding the
                logits of ``groups`` distributions of ``c1`` bins each
                (``groups`` is 4 in the original formulation).

        Returns:
            Tensor of shape ``(batch, groups, anchors)``.

        Raises:
            RuntimeError: Propagated from ``view`` when the channel dimension
                is not a multiple of ``c1``.
        """
        b, c, a = x.shape  # batch, channels, anchors
        # Derive the group count from the input instead of hard-coding 4;
        # identical for the usual 4 * c1 layout.
        groups = c // self.c1
        # (b, groups, c1, a) -> (b, c1, groups, a) so that the bins sit on the
        # channel axis: softmax normalises over bins, the conv sums them.
        probs = x.view(b, groups, self.c1, a).transpose(2, 1).softmax(1)
        return self.conv(probs).view(b, groups, a)
|