
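# 今天一个朋友用 YOLOv4 预测图片时报错：size mismatch for yolo_head2.1.bias: copying a param with shape torch.Size([75]) from checkpoint, the shape in current model is torch.Size([18]).
# 原因通常是类别数不一致：75 = 3 × (5 + 20)，说明权重是按 20 类训练的；18 = 3 × (5 + 1)，说明当前模型是按 1 类构建的。预测时使用的类别文件必须与训练该权重时的一致。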


def make_dataset(dir, opt):
    """Collect the image files under ``dir`` that qualify for ``opt.phase``.

    The directory tree is walked recursively and every file accepted by
    ``is_image_file`` is considered:

    * ``opt.phase == 'test'``: every image is kept.
    * ``opt.phase == 'train'``: only images whose shorter side is at least
      512 pixels are kept.
    * any other phase: nothing is kept.

    ``is_image_file`` and ``Image`` (PIL) are expected at module scope.

    Args:
        dir: Root directory to scan.
        opt: Options object; only its ``phase`` attribute is read.

    Returns:
        List of image paths (``os.path.join(root, fname)``), ordered by
        directory and then by file name so the result is reproducible.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    import os  # the original snippet used os without importing it

    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk yields file names in arbitrary filesystem order; sort them
        # so the dataset order is deterministic, like the sorted() walk.
        for fname in sorted(fnames):
            if not is_image_file(fname):
                continue
            path = os.path.join(root, fname)
            if opt.phase == 'test':
                images.append(path)
            elif opt.phase == 'train':
                # Image.open keeps the file open until closed; the with-block
                # releases the handle, which the original one-liner never did.
                with Image.open(path) as img:
                    if min(img.size) >= 512:
                        images.append(path)
    return images




import PIL.Image as Image
# Resize one image to a fixed width of 500 px, keeping its aspect ratio.
infile = 'images/train/202108315.jpg'
outfile = 'images/train/1202108315.jpg'
x_s = 500  # target width in pixels
with Image.open(infile) as im:  # close the source file once resized
    (x, y) = im.size
    # Scale the height by the same factor as the width; clamp to 1 so a very
    # wide, short image never yields an invalid zero-height target size.
    y_s = max(1, int(y * x_s / x))
    # Image.ANTIALIAS was removed in Pillow 10. It was an alias of LANCZOS,
    # which is still available, so the resampling filter is unchanged.
    out = im.resize((x_s, y_s), Image.LANCZOS)
out.save(outfile)
print('infile size: ', x, y)
print('outfile size: ', x_s, y_s)
# -*- coding: utf-8 -*-
import numpy as np
import codecs
import json
from glob import glob
import cv2
# 1. Directory containing the labelme JSON annotation files.
labelme_path = "labelme_json/"
# Output path: the XML files are written to "new_xml/" by the conversion loop.
isUseTest = True  # whether to create a test split; NOTE(review): not read by the code shown here
# # 2. Create the required folders
# if not os.path.exists("Annotations"):
#     os.makedirs("Annotations")
# if not os.path.exists("JPEGImages/"):
#     os.makedirs("JPEGImages/")
# if not os.path.exists("ImageSets/Main/"):
#     os.makedirs("ImageSets/Main/")
# 3. Collect the base names (directory and ".json" extension stripped) of the
#    annotation files to process, sorted so every run uses the same order.
import os

files = sorted(glob(labelme_path + "*.json"))
# basename + splitext remove exactly the directory and the trailing ".json";
# the previous split(".json")[0] truncated names containing ".json" elsewhere.
files = [os.path.splitext(os.path.basename(path))[0] for path in files]
# 4. Read each labelme annotation and write it out as a Pascal VOC-style XML file.
import os
from xml.sax.saxutils import escape

# Nothing else in this script creates the output directory; make sure it
# exists so codecs.open below does not fail on a fresh checkout.
os.makedirs("new_xml/", exist_ok=True)

for json_file_ in files:
    json_filename = labelme_path + json_file_ + ".json"
    # with-block so the JSON file handle is closed right after parsing.
    with open(json_filename, "r", encoding="utf-8") as json_fp:
        json_file = json.load(json_fp)

    # The image is expected next to its JSON file, with the same base name.
    image_filename = labelme_path + json_file_ + ".jpg"
    image = cv2.imread(image_filename)
    if image is None:
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail with a clear message rather than an AttributeError.
        raise FileNotFoundError("cannot read image %s" % image_filename)
    # Default imread flag (IMREAD_COLOR) always gives a 3-D array.
    height, width, channels = image.shape

    with codecs.open("new_xml/" + json_file_ + ".xml", "w", "utf-8") as xml_fp:
        header = [
            '<annotation>\n',
            '\t<folder>' + 'WH_data' + '</folder>\n',
            # escape() keeps '&', '<', '>' in names from producing invalid XML.
            '\t<filename>' + escape(json_file_ + ".jpg") + '</filename>\n',
            '\t<source>\n',
            '\t\t<database>WH Data</database>\n',
            '\t\t<annotation>WH</annotation>\n',
            '\t\t<image>flickr</image>\n',
            '\t\t<flickrid>NULL</flickrid>\n',
            '\t</source>\n',
            '\t<owner>\n',
            '\t\t<flickrid>NULL</flickrid>\n',
            '\t\t<name>WH</name>\n',
            '\t</owner>\n',
            '\t<size>\n',
            '\t\t<width>' + str(width) + '</width>\n',
            '\t\t<height>' + str(height) + '</height>\n',
            '\t\t<depth>' + str(channels) + '</depth>\n',
            '\t</size>\n',
            # <segmented> is a direct child of <annotation>: one tab, not two.
            '\t<segmented>0</segmented>\n',
        ]
        xml_fp.write(''.join(header))

        for multi in json_file["shapes"]:
            # Bounding box = extent of the shape's points (column 0 = x, column 1 = y).
            points = np.array(multi["points"])
            label = multi["label"]
            xmin, xmax = points[:, 0].min(), points[:, 0].max()
            ymin, ymax = points[:, 1].min(), points[:, 1].max()
            # Skip degenerate shapes with zero width or height.
            if xmax <= xmin or ymax <= ymin:
                continue
            xml_fp.write(''.join([
                '\t<object>\n',
                '\t\t<name>' + escape(label) + '</name>\n',
                '\t\t<pose>Unspecified</pose>\n',
                '\t\t<truncated>1</truncated>\n',
                '\t\t<difficult>0</difficult>\n',
                '\t\t<bndbox>\n',
                '\t\t\t<xmin>' + str(int(xmin)) + '</xmin>\n',
                '\t\t\t<ymin>' + str(int(ymin)) + '</ymin>\n',
                '\t\t\t<xmax>' + str(int(xmax)) + '</xmax>\n',
                '\t\t\t<ymax>' + str(int(ymax)) + '</ymax>\n',
                '\t\t</bndbox>\n',
                '\t</object>\n',
            ]))
            print(json_filename, xmin, ymin, xmax, ymax, label)
        xml_fp.write('</annotation>')




