Reference articles:
Baidu AI in practice: taxi invoice recognition (CSDN blog)
An introductory case of invoice recognition based on TensorFlow and OpenCV, part I: locating the key regions (complete Python source code attached) (CSDN blog)
Note: thanks to my teammates, with whom this project was completed.
Official competition page: Taxi Invoice Recognition, Competitions - DataFountain
I. Competition description
1. Competition background
Taxi invoices are a common item in day-to-day expense reimbursement. Because these invoices come in many styles, vary strongly by region, and contain a large amount of blurred and misaligned print, it is important to locate the text fields on an invoice accurately, recognize the text correctly, and output the result in structured form.
2. Competition task
The task of this competition is to train text detection and recognition models for taxi invoices using image processing, machine learning, deep learning and other methods, and to produce structured output from the recognition results.
3. Data introduction
The data comes from reimbursement taxi invoices collected in real production and daily life.
4. Data description
More than 10 sample invoices from different regions are provided so that contestants can familiarize themselves with the styles of the test taxi invoices. The validation set contains no fewer than 500 samples.
The data includes two parts: the first is whole invoice images, used to evaluate the text detection algorithm and the final structured output; the second is field crops cut from the invoice images, used to evaluate the character recognition algorithm (recognizing text from the cropped images).
5. Evaluation criteria
Text detection is scored by precision, recall and F-measure; character recognition is scored by whole-field accuracy and per-character accuracy; structured output is scored by the whole-field accuracy and per-character accuracy of the output fields.
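The official scoring script is not public, but the two recognition metrics are easy to state precisely. Below is a minimal sketch (helper names are our own, and the per-character metric is assumed to be edit-distance based) of whole-field and per-character accuracy:

def edit_distance(a, b):
    # Levenshtein distance with a single rolling row
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[-1]

def field_accuracy(preds, gts):
    # whole-field: a prediction counts only if every character matches
    return sum(p == g for p, g in zip(preds, gts)) / len(gts)

def char_accuracy(preds, gts):
    # per-character: 1 - (total edit distance / total ground-truth characters)
    errors = sum(edit_distance(p, g) for p, g in zip(preds, gts))
    return 1 - errors / sum(len(g) for g in gts)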
II. Solution approach
Deep learning framework: PyTorch. Text detection algorithm: the CTPN framework. Character recognition algorithm: the CRNN + CTC framework.
The overall pipeline is as follows: a CNN first extracts convolutional features from the image, an LSTM then extracts sequence features from those convolutional features, and finally CTC is introduced to solve the problem that predictions and label characters cannot be aligned one-to-one during training.
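PyTorch ships this alignment-free objective as nn.CTCLoss; the sketch below only illustrates the expected shapes (all sizes here are made up):

import torch
import torch.nn as nn

T, N, C = 33, 4, 37                    # time steps, batch size, classes (index 0 = CTC blank)
log_probs = torch.randn(T, N, C, requires_grad=True).log_softmax(2)
targets = torch.randint(1, C, (N, 10), dtype=torch.long)   # label indices, no blanks
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)

loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
loss.backward()                        # CTC marginalizes over all valid alignments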
III. Experimental process
1. Set up the PyTorch environment
1) Check the version compatibility of CUDA, cuDNN, Python, torch and torchvision;
2) Preparations: download and install Anaconda and PyCharm;
download and install CUDA and cuDNN;
download torch and torchvision;
switch the conda and pip sources to the Tsinghua mirror;
(Switching conda and pip to domestic mirror sources — lazy_boy's blog, CSDN)
create a conda virtual environment with Python 3.8;
(Creating a Python 3.6 virtual environment with conda — lazy_boy's blog, CSDN)
3) Install and test torch:
activate the new virtual environment in the Anaconda command line, install the wheel with pip, and run an import test.
4) Install PyTorch in the virtual environment.
5) Create a test project in PyCharm: there you can query the CUDA and cuDNN versions.
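For example, the following lines in the test project print the installed torch build and the CUDA/cuDNN versions it was compiled against:

import torch

print(torch.__version__)               # torch version
print(torch.version.cuda)              # CUDA version torch was built with
print(torch.backends.cudnn.version())  # cuDNN version
print(torch.cuda.is_available())       # True if the GPU can be used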
2. Text detection
Main idea: VGG extracts features, a bidirectional LSTM (BLSTM) integrates context information, and an RPN-style head completes the detection.
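For intuition, here is a rough sketch of CTPN-style anchor generation (the project's real implementation is gen_anchor in ctpn_utils; the ten heights below follow common CTPN implementations and are an assumption):

import numpy as np

def gen_anchors_sketch(feat_h, feat_w, stride=16):
    # 10 anchor heights, width fixed at 16 px to match the feature stride
    heights = np.array([11, 16, 23, 33, 48, 68, 97, 139, 198, 283])
    widths = np.full_like(heights, 16)
    anchors = []
    for y in range(feat_h):
        for x in range(feat_w):
            cx, cy = x * stride + stride / 2, y * stride + stride / 2
            for h, w in zip(heights, widths):
                anchors.append([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2])
    return np.array(anchors)  # (feat_h * feat_w * 10, 4) in (x1, y1, x2, y2)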
The specific steps are as follows:
1) Write dataset.py to complete the data preprocessing.
import os
import xml.etree.ElementTree as ET

import numpy as np
import cv2
import torch
from torch.utils.data import Dataset

from config import IMAGE_MEAN
from ctpn_utils import cal_rpn


def readxml(path):
    gtboxes = []
    imgfile = ''
    xml = ET.parse(path)
    for elem in xml.iter():
        if 'filename' in elem.tag:
            imgfile = elem.text
        if 'object' in elem.tag:
            for attr in list(elem):
                if 'bndbox' in attr.tag:
                    xmin = int(round(float(attr.find('xmin').text)))
                    ymin = int(round(float(attr.find('ymin').text)))
                    xmax = int(round(float(attr.find('xmax').text)))
                    ymax = int(round(float(attr.find('ymax').text)))
                    gtboxes.append((xmin, ymin, xmax, ymax))
    return np.array(gtboxes), imgfile


# for ctpn text detection
class VOCDataset(Dataset):
    def __init__(self, datadir, labelsdir):
        '''
        :param datadir: images' directory
        :param labelsdir: annotations' directory
        '''
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))
        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        print(img_path)
        xml_path = os.path.join(self.labelsdir, img_name.replace('.jpg', '.xml'))
        gtbox, _ = readxml(xml_path)
        img = cv2.imread(img_path)
        h, w, c = img.shape

        # random horizontal flip (boxes are mirrored accordingly)
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr], _ = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)

        m_img = img - IMAGE_MEAN
        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
        cls = np.expand_dims(cls, axis=0)

        # transform to torch tensors
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()
        return m_img, cls, regr


class ICDARDataset(Dataset):
    def __init__(self, datadir, labelsdir):
        '''
        :param datadir: images' directory
        :param labelsdir: annotations' directory
        '''
        if not os.path.isdir(datadir):
            raise Exception('[ERROR] {} is not a directory'.format(datadir))
        if not os.path.isdir(labelsdir):
            raise Exception('[ERROR] {} is not a directory'.format(labelsdir))
        self.datadir = datadir
        self.img_names = os.listdir(self.datadir)
        self.labelsdir = labelsdir

    def __len__(self):
        return len(self.img_names)

    def box_transfer(self, coor_lists, rescale_fac=1.0):
        gtboxes = []
        for coor_list in coor_lists:
            coors_x = [int(coor_list[2 * i]) for i in range(4)]
            coors_y = [int(coor_list[2 * i + 1]) for i in range(4)]
            xmin = min(coors_x)
            xmax = max(coors_x)
            ymin = min(coors_y)
            ymax = max(coors_y)
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            gtboxes.append((xmin, ymin, xmax, ymax))
        return np.array(gtboxes)

    def box_transfer_v2(self, coor_lists, rescale_fac=1.0):
        # split each ground-truth box into fine-scale boxes of 16-px width
        gtboxes = []
        for coor_list in coor_lists:
            coors_x = [int(coor_list[2 * i]) for i in range(4)]
            coors_y = [int(coor_list[2 * i + 1]) for i in range(4)]
            xmin = min(coors_x)
            xmax = max(coors_x)
            ymin = min(coors_y)
            ymax = max(coors_y)
            if rescale_fac > 1.0:
                xmin = int(xmin / rescale_fac)
                xmax = int(xmax / rescale_fac)
                ymin = int(ymin / rescale_fac)
                ymax = int(ymax / rescale_fac)
            prev = xmin
            for i in range(xmin // 16 + 1, xmax // 16 + 1):
                next = 16 * i - 0.5
                gtboxes.append((prev, ymin, next, ymax))
                prev = next
            gtboxes.append((prev, ymin, xmax, ymax))
        return np.array(gtboxes)

    def parse_gtfile(self, gt_path, rescale_fac=1.0):
        coor_lists = list()
        with open(gt_path) as f:
            content = f.readlines()
            for line in content:
                coor_list = line.split(',')[:8]
                if len(coor_list) == 8:
                    coor_lists.append(coor_list)
        return self.box_transfer_v2(coor_lists, rescale_fac)

    def draw_boxes(self, img, cls, base_anchors, gt_box):
        for i in range(len(cls)):
            if cls[i] == 1:
                pt1 = (int(base_anchors[i][0]), int(base_anchors[i][1]))
                pt2 = (int(base_anchors[i][2]), int(base_anchors[i][3]))
                img = cv2.rectangle(img, pt1, pt2, (200, 100, 100))
        for i in range(gt_box.shape[0]):
            pt1 = (int(gt_box[i][0]), int(gt_box[i][1]))
            pt2 = (int(gt_box[i][2]), int(gt_box[i][3]))
            img = cv2.rectangle(img, pt1, pt2, (100, 200, 100))
        return img

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.datadir, img_name)
        img = cv2.imread(img_path)
        ##### on read error, fall back to a default image #####
        if img is None:
            print(img_path)
            with open('error_imgs.txt', 'a') as f:
                f.write('{}\n'.format(img_path))
            img_name = 'img_2647.jpg'
            img_path = os.path.join(self.datadir, img_name)
            img = cv2.imread(img_path)
        #######################################################
        h, w, c = img.shape
        rescale_fac = max(h, w) / 1600
        if rescale_fac > 1.0:
            h = int(h / rescale_fac)
            w = int(w / rescale_fac)
            img = cv2.resize(img, (w, h))

        gt_path = os.path.join(self.labelsdir, 'gt_' + img_name.split('.')[0] + '.txt')
        gtbox = self.parse_gtfile(gt_path, rescale_fac)

        # random horizontal flip (boxes are mirrored accordingly)
        if np.random.randint(2) == 1:
            img = img[:, ::-1, :]
            newx1 = w - gtbox[:, 2] - 1
            newx2 = w - gtbox[:, 0] - 1
            gtbox[:, 0] = newx1
            gtbox[:, 2] = newx2

        [cls, regr], base_anchors = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox)
        # debug_img = self.draw_boxes(img.copy(), cls, base_anchors, gtbox)
        # cv2.imwrite('debug/{}'.format(img_name), debug_img)
        m_img = img - IMAGE_MEAN
        regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
        cls = np.expand_dims(cls, axis=0)

        # transform to torch tensors
        m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float()
        cls = torch.from_numpy(cls).float()
        regr = torch.from_numpy(regr).float()
        return m_img, cls, regr


if __name__ == '__main__':
    xmin = 15
    xmax = 95
    for i in range(xmin // 16 + 1, xmax // 16 + 1):
        print(16 * i - 0.5)
2) Define the CTPN network architecture (with its losses) and initialize it.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

import config


class RPN_REGR_Loss(nn.Module):
    def __init__(self, device, sigma=9.0):
        super(RPN_REGR_Loss, self).__init__()
        self.sigma = sigma
        self.device = device

    def forward(self, input, target):
        '''
        smooth L1 loss
        :param input: y_preds
        :param target: y_true
        '''
        try:
            cls = target[0, :, 0]
            regr = target[0, :, 1:3]
            # apply regression to positive samples only
            regr_keep = (cls == 1).nonzero()[:, 0]
            regr_true = regr[regr_keep]
            regr_pred = input[0][regr_keep]
            diff = torch.abs(regr_true - regr_pred)
            less_one = (diff < 1.0 / self.sigma).float()
            loss = less_one * 0.5 * diff ** 2 * self.sigma + \
                   torch.abs(1 - less_one) * (diff - 0.5 / self.sigma)
            loss = torch.sum(loss, 1)
            loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
        except Exception as e:
            print('RPN_REGR_Loss Exception:', e)
            loss = torch.tensor(0.0)
        return loss.to(self.device)


class RPN_CLS_Loss(nn.Module):
    def __init__(self, device):
        super(RPN_CLS_Loss, self).__init__()
        self.device = device
        self.L_cls = nn.CrossEntropyLoss(reduction='none')
        self.pos_neg_ratio = 3

    def forward(self, input, target):
        if config.OHEM:
            cls_gt = target[0][0]
            num_pos = 0
            loss_pos_sum = 0
            if len((cls_gt == 1).nonzero()) != 0:  # avoid the case of zero positive samples
                cls_pos = (cls_gt == 1).nonzero()[:, 0]
                gt_pos = cls_gt[cls_pos].long()
                cls_pred_pos = input[0][cls_pos]
                loss_pos = self.L_cls(cls_pred_pos.view(-1, 2), gt_pos.view(-1))
                loss_pos_sum = loss_pos.sum()
                num_pos = len(loss_pos)
            cls_neg = (cls_gt == 0).nonzero()[:, 0]
            gt_neg = cls_gt[cls_neg].long()
            cls_pred_neg = input[0][cls_neg]
            loss_neg = self.L_cls(cls_pred_neg.view(-1, 2), gt_neg.view(-1))
            # online hard example mining: keep only the hardest negatives
            loss_neg_topK, _ = torch.topk(loss_neg, min(len(loss_neg), config.RPN_TOTAL_NUM - num_pos))
            loss_cls = loss_pos_sum + loss_neg_topK.sum()
            loss_cls = loss_cls / config.RPN_TOTAL_NUM
            return loss_cls.to(self.device)
        else:
            y_true = target[0][0]
            cls_keep = (y_true != -1).nonzero()[:, 0]
            cls_true = y_true[cls_keep].long()
            cls_pred = input[0][cls_keep]
            # original TF version used sparse_softmax_cross_entropy_with_logits
            loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true)
            loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0)
            return loss.to(self.device)


class basic_conv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu=True, bn=True, bias=True):
        super(basic_conv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class CTPN_Model(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = models.vgg16(pretrained=False)
        layers = list(base_model.features)[:-1]
        self.base_layers = nn.Sequential(*layers)  # block5_conv3 output
        self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False)
        self.brnn = nn.GRU(512, 128, bidirectional=True, batch_first=True)
        self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False)
        self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)
        self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False)

    def forward(self, x):
        x = self.base_layers(x)
        # rpn
        x = self.rpn(x)  # [b, c, h, w]

        x1 = x.permute(0, 2, 3, 1).contiguous()  # channels last [b, h, w, c]
        b = x1.size()  # b, h, w, c
        x1 = x1.view(b[0] * b[1], b[2], b[3])
        x2, _ = self.brnn(x1)

        xsz = x.size()
        x3 = x2.view(xsz[0], xsz[2], xsz[3], 256)  # e.g. torch.Size([4, 20, 20, 256])
        x3 = x3.permute(0, 3, 1, 2).contiguous()  # channels first [b, c, h, w]
        x3 = self.lstm_fc(x3)
        x = x3

        cls = self.rpn_class(x)
        regr = self.rpn_regress(x)

        cls = cls.permute(0, 2, 3, 1).contiguous()
        regr = regr.permute(0, 2, 3, 1).contiguous()
        cls = cls.view(cls.size(0), cls.size(1) * cls.size(2) * 10, 2)
        regr = regr.view(regr.size(0), regr.size(1) * regr.size(2) * 10, 2)
        return cls, regr
3) Generate candidate text boxes for character recognition (a simplified sketch of the proposal-grouping idea follows).
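CTPN outputs fine-scale proposals of fixed 16-px width, so neighbouring proposals must be merged into whole text lines before they can be cropped for recognition. The project uses TextProposalConnectorOriented for this (see the prediction script below); the simplified, axis-aligned sketch here only illustrates the grouping idea, and both thresholds are made up:

import numpy as np

def connect_proposals(boxes, max_gap=50, min_v_overlap=0.7):
    # boxes: (N, 4) array of (xmin, ymin, xmax, ymax) fine-scale proposals
    if len(boxes) == 0:
        return np.zeros((0, 4))
    boxes = boxes[np.argsort(boxes[:, 0])]          # left to right
    groups, cur = [], [boxes[0]]
    for b in boxes[1:]:
        last = cur[-1]
        gap = b[0] - last[2]                        # horizontal gap to the previous proposal
        overlap = (min(b[3], last[3]) - max(b[1], last[1])) / \
                  max(b[3] - b[1], last[3] - last[1])
        if gap < max_gap and overlap > min_v_overlap:
            cur.append(b)                           # same text line
        else:
            groups.append(cur)
            cur = [b]
    groups.append(cur)
    # each text line is the bounding box of its group
    return np.array([[np.array(g)[:, 0].min(), np.array(g)[:, 1].min(),
                      np.array(g)[:, 2].max(), np.array(g)[:, 3].max()] for g in groups])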
4) Model training and prediction
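The training script itself is not shown; below is a minimal training-loop sketch wired up from the dataset and losses defined above (directory names, optimizer settings and the epoch count are illustrative assumptions). The snippet that follows it is the prediction script, which loads a trained checkpoint and draws the detected text lines.

import torch
from torch.utils.data import DataLoader

from dataset import ICDARDataset
from ctpn_model import CTPN_Model, RPN_CLS_Loss, RPN_REGR_Loss

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset = ICDARDataset('train_images', 'train_labels')    # hypothetical paths
loader = DataLoader(dataset, batch_size=1, shuffle=True)  # CTPN is trained one image at a time

model = CTPN_Model().to(device)
cls_loss_fn = RPN_CLS_Loss(device)
regr_loss_fn = RPN_REGR_Loss(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

for epoch in range(30):                                   # epoch count is illustrative
    for m_img, cls_gt, regr_gt in loader:
        m_img = m_img.to(device)
        cls_gt, regr_gt = cls_gt.to(device), regr_gt.to(device)
        optimizer.zero_grad()
        cls_pred, regr_pred = model(m_img)
        loss = cls_loss_fn(cls_pred, cls_gt) + regr_loss_fn(regr_pred, regr_gt)
        loss.backward()
        optimizer.step()
    torch.save({'model_state_dict': model.state_dict()},
               'ctpn_ep{:02d}.pth'.format(epoch))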
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from ctpn_model import CTPN_Model
from ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox, nms, TextProposalConnectorOriented
from ctpn_utils import resize
import config

prob_thresh = 0.5
width = 960
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
weights = os.path.join(config.checkpoints_dir, 'v3_ctpn_ep30_0.3699_0.0929_0.4628.pth')
model = CTPN_Model()
model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict'])
model.to(device)
model.eval()


def dis(image):
    cv2.imshow('image', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def get_det_boxes(image, display=True):
    image = resize(image, height=720)
    image_c = image.copy()
    h, w = image.shape[:2]
    image = image.astype(np.float32) - config.IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float()

    with torch.no_grad():
        image = image.to(device)
        cls, regr = model(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()
        anchor = gen_anchor((int(h / 16), int(w / 16)), 16)
        bbox = bbox_transfor_inv(anchor, regr)
        bbox = clip_box(bbox, [h, w])

        fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
        select_anchor = bbox[fg, :]
        select_score = cls_prob[0, fg, 1]
        select_anchor = select_anchor.astype(np.int32)
        keep_index = filter_bbox(select_anchor, 16)

        # nms
        select_anchor = select_anchor[keep_index]
        select_score = select_score[keep_index]
        select_score = np.reshape(select_score, (select_score.shape[0], 1))
        nmsbox = np.hstack((select_anchor, select_score))
        keep = nms(nmsbox, 0.3)
        select_anchor = select_anchor[keep]
        select_score = select_score[keep]

        # text lines
        textConn = TextProposalConnectorOriented()
        text = textConn.get_text_lines(select_anchor, select_score, [h, w])
        print(text)

        if display:
            for i in text:
                s = str(round(i[-1] * 100, 2)) + '%'
                i = [int(j) for j in i]
                cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2)
                cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2)
                cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2)
                cv2.putText(image_c, s, (i[0] + 13, i[1] + 13), cv2.FONT_HERSHEY_SIMPLEX,
                            1, (255, 0, 0), 2, cv2.LINE_AA)
        return text, image_c


if __name__ == '__main__':
    img_path = 'images/t1.png'
    image = cv2.imread(img_path)
    text, image = get_det_boxes(image)
    cv2.imwrite('results/t.jpg', image)
    # dis(image)
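To hand the detection results over to the recognizer, each detected line can be cropped out of the image. A minimal sketch (the output directory and file naming are our own; text_lines is the first return value of get_det_boxes, whose rows hold 8 corner coordinates plus a score):

import os
import cv2
import numpy as np

def crop_text_lines(image, text_lines, out_dir='fields'):
    os.makedirs(out_dir, exist_ok=True)
    crops = []
    for k, line in enumerate(text_lines):
        pts = np.array(line[:8], dtype=np.int32).reshape(4, 2)  # 4 corners of the line
        x, y, w, h = cv2.boundingRect(pts)                      # axis-aligned enclosing box
        crop = image[y:y + h, x:x + w]
        crops.append(crop)
        cv2.imwrite(os.path.join(out_dir, 'field_{}.jpg'.format(k)), crop)
    return crops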
3. Character recognition
The CRNN algorithm is used. It consists of a CNN, an RNN and CTC, corresponding to the convolutional layer, the recurrent layer and the transcription layer respectively. The CNN first extracts image features, the RNN then makes per-time-step sequence predictions, and finally a CTC transcription layer produces the final string.
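At inference time the transcription step, in its simplest greedy (best-path) form, takes the arg-max label at every time step, collapses consecutive repeats, and removes blanks. A sketch (the alphabet argument is an assumption, with index 0 reserved for the CTC blank):

import torch

def ctc_greedy_decode(logits, alphabet, blank=0):
    # logits: (T, b, nclass) raw CRNN output
    best = logits.argmax(dim=2)                    # (T, b) best label per time step
    results = []
    for i in range(best.size(1)):
        chars, prev = [], blank
        for t in best[:, i].tolist():
            if t != blank and t != prev:           # drop blanks and repeats
                chars.append(alphabet[t - 1])      # index 0 is the blank
            prev = t
        results.append(''.join(chars))
    return results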
The CNN adopts the classic VGG16, and the RNN part uses a bidirectional LSTM. Note that the input an LSTM unit receives in PyTorch must be a three-dimensional tensor, where each dimension has a different meaning.
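A quick illustration of that convention: by default nn.LSTM expects (seq_len, batch, input_size), which is why the CRNN below permutes its feature map to [w, b, c] before the recurrent layers.

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=512, hidden_size=256, bidirectional=True)
x = torch.randn(16, 4, 512)        # 16 time steps, batch of 4, 512 features per step
out, (h, c) = rnn(x)
print(out.shape)                   # torch.Size([16, 4, 512]): 2 * hidden_size for a BLSTM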
Code for the CRNN part:
import torch.nn as nn
from collections import OrderedDict


class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # 1x32x128
        self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1)
        self.relu1 = nn.ReLU(True)
        self.pool1 = nn.MaxPool2d(2, 2)

        # 64x16x64
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.relu2 = nn.ReLU(True)
        self.pool2 = nn.MaxPool2d(2, 2)

        # 128x8x32
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3_1 = nn.ReLU(True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.relu3_2 = nn.ReLU(True)
        self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 256x4x16
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(512)
        self.relu4_1 = nn.ReLU(True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu4_2 = nn.ReLU(True)
        self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 512x2x16
        self.conv5 = nn.Conv2d(512, 512, 2, 1, 0)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(True)

        # 512x1x16
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        x = self.pool1(self.relu1(self.conv1(input)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3_2(self.conv3_2(self.relu3_1(self.bn3(self.conv3_1(x))))))
        x = self.pool4(self.relu4_2(self.conv4_2(self.relu4_1(self.bn4(self.conv4_1(x))))))
        conv = self.relu5(self.bn5(self.conv5(x)))

        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)
        return output


class CRNN_v2(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, leakyRelu=False):
        super(CRNN_v2, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # 1x32x128
        self.conv1_1 = nn.Conv2d(nc, 32, 3, 1, 1)
        self.bn1_1 = nn.BatchNorm2d(32)
        self.relu1_1 = nn.ReLU(True)
        self.conv1_2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.bn1_2 = nn.BatchNorm2d(64)
        self.relu1_2 = nn.ReLU(True)
        self.pool1 = nn.MaxPool2d(2, 2)

        # 64x16x64
        self.conv2_1 = nn.Conv2d(64, 64, 3, 1, 1)
        self.bn2_1 = nn.BatchNorm2d(64)
        self.relu2_1 = nn.ReLU(True)
        self.conv2_2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn2_2 = nn.BatchNorm2d(128)
        self.relu2_2 = nn.ReLU(True)
        self.pool2 = nn.MaxPool2d(2, 2)

        # 128x8x32
        self.conv3_1 = nn.Conv2d(128, 96, 3, 1, 1)
        self.bn3_1 = nn.BatchNorm2d(96)
        self.relu3_1 = nn.ReLU(True)
        self.conv3_2 = nn.Conv2d(96, 192, 3, 1, 1)
        self.bn3_2 = nn.BatchNorm2d(192)
        self.relu3_2 = nn.ReLU(True)
        self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 192x4x32
        self.conv4_1 = nn.Conv2d(192, 128, 3, 1, 1)
        self.bn4_1 = nn.BatchNorm2d(128)
        self.relu4_1 = nn.ReLU(True)
        self.conv4_2 = nn.Conv2d(128, 256, 3, 1, 1)
        self.bn4_2 = nn.BatchNorm2d(256)
        self.relu4_2 = nn.ReLU(True)
        self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1))

        # 256x2x32
        self.bn5 = nn.BatchNorm2d(256)

        # 256x2x32 -> (256*2)x32 after the reshape in forward
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        x = self.pool1(self.relu1_2(self.bn1_2(self.conv1_2(self.relu1_1(self.bn1_1(self.conv1_1(input)))))))
        x = self.pool2(self.relu2_2(self.bn2_2(self.conv2_2(self.relu2_1(self.bn2_1(self.conv2_1(x)))))))
        x = self.pool3(self.relu3_2(self.bn3_2(self.conv3_2(self.relu3_1(self.bn3_1(self.conv3_1(x)))))))
        x = self.pool4(self.relu4_2(self.bn4_2(self.conv4_2(self.relu4_1(self.bn4_1(self.conv4_1(x)))))))
        conv = self.bn5(x)

        b, c, h, w = conv.size()
        assert h == 2, "the height of conv must be 2"
        conv = conv.reshape([b, c * h, w])
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)
        return output


def conv3x3(nIn, nOut, stride=1):
    # "3x3 convolution with padding"
    return nn.Conv2d(nIn, nOut, kernel_size=3, stride=stride, padding=1, bias=False)


class basic_res_block(nn.Module):
    def __init__(self, nIn, nOut, stride=1, downsample=None):
        super(basic_res_block, self).__init__()
        m = OrderedDict()
        m['conv1'] = conv3x3(nIn, nOut, stride)
        m['bn1'] = nn.BatchNorm2d(nOut)
        m['relu1'] = nn.ReLU(inplace=True)
        m['conv2'] = conv3x3(nOut, nOut)
        m['bn2'] = nn.BatchNorm2d(nOut)
        self.group1 = nn.Sequential(m)
        self.relu = nn.Sequential(nn.ReLU(inplace=True))
        self.downsample = downsample

    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
        out = self.group1(x) + residual
        out = self.relu(out)
        return out


class CRNN_res(nn.Module):
    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN_res, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1)
        self.relu1 = nn.ReLU(True)
        self.res1 = basic_res_block(64, 64)

        # 1x32x128
        down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
                              nn.BatchNorm2d(128))
        self.res2_1 = basic_res_block(64, 128, 2, down1)
        self.res2_2 = basic_res_block(128, 128)

        # 64x16x64
        down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
                              nn.BatchNorm2d(256))
        self.res3_1 = basic_res_block(128, 256, 2, down2)
        self.res3_2 = basic_res_block(256, 256)
        self.res3_3 = basic_res_block(256, 256)

        # 128x8x32
        down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=(2, 1), bias=False),
                              nn.BatchNorm2d(512))
        self.res4_1 = basic_res_block(256, 512, (2, 1), down3)
        self.res4_2 = basic_res_block(512, 512)
        self.res4_3 = basic_res_block(512, 512)

        # 256x4x16
        self.pool = nn.AvgPool2d((2, 2), (2, 1), (0, 1))

        # 512x2x16
        self.conv5 = nn.Conv2d(512, 512, 2, 1, 0)
        self.bn5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(True)

        # 512x1x16
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        x = self.res1(self.relu1(self.conv1(input)))
        x = self.res2_2(self.res2_1(x))
        x = self.res3_3(self.res3_2(self.res3_1(x)))
        x = self.res4_3(self.res4_2(self.res4_1(x)))
        x = self.pool(x)
        conv = self.relu5(self.bn5(self.conv5(x)))

        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)
        return output


if __name__ == '__main__':
    pass
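A quick shape check of the CRNN above (class count and batch size are illustrative; 37 would correspond to, e.g., 36 characters plus the CTC blank):

import torch

model = CRNN(imgH=32, nc=1, nclass=37, nh=256)
x = torch.randn(4, 1, 32, 128)     # batch of 4 grayscale crops resized to height 32
out = model(x)
print(out.shape)                   # torch.Size([33, 4, 37]): (T, b, nclass) for CTC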
Please refer to: https://github.com/breezedeus/cnstd