
Training YOLO-NAS [Python]

agingcurve 2023. 6. 7. 22:22

https://www.youtube.com/watch?v=V-H3eoPUnA8 

 

In early 2023, Ultralytics released YOLOv8 as the successor to YOLOv5, and YOLOv6 once again set a new state of the art (SOTA) in object detection. I studied how to train YOLO-NAS, which shows better accuracy and speed than both of them.

https://www.youtube.com/watch?v=91p2SkSuZkc 

 

I ran everything in Google Colab.

https://colab.research.google.com/drive/1yHrHkUR1X2u2FjjvNMfUbSXTkUul6o1P?usp=sharing 

 

DeciYoloCustomDatasetQAFineTuning.ipynb

 

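This notebook contains code for quantization-aware training (QAT), which relies on NVIDIA's pytorch-quantization package installed below, but it did not work properly for me, so I used regular training instead.

I plan to write a follow-up post once I get QAT working.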

 

 

Installation

! pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 &> /dev/null
! pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com &> /dev/null
! pip install git+https://github.com/Deci-AI/super-gradients.git@master  --upgrade &> /dev/null
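Because every install command above sends its output to /dev/null, a failed install is silent. A quick way to confirm what actually got installed is to query package metadata with the standard library's importlib.metadata. This is a small helper added for illustration, not part of the original notebook; the distribution names are assumed to match the pip commands, and hyphen/underscore name normalization can vary slightly between Python versions.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is missing."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed, or the (silenced) pip install failed
            result[name] = None
    return result


if __name__ == "__main__":
    # Distribution names assumed from the pip commands above
    pkgs = ["torch", "torchvision", "torchaudio", "pytorch-quantization", "super-gradients"]
    for pkg, ver in installed_versions(pkgs).items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```

Run it in a new cell right after installation; a None entry means that install needs to be rerun without the redirect to see the error.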

Training

 

from super_gradients.training.datasets.detection_datasets.coco_format_detection import COCOFormatDetectionDataset
from super_gradients.training.transforms.transforms import DetectionMosaic, DetectionRandomAffine, DetectionHSV, \
    DetectionHorizontalFlip, DetectionPaddedRescale, DetectionStandardize, DetectionTargetsFormatTransform
from super_gradients.training.utils.detection_utils import DetectionCollateFN
from super_gradients.training import dataloaders
from super_gradients.training.datasets.datasets_utils import worker_init_reset_seed
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training import Trainer
from super_gradients.common.object_names import Models
from super_gradients.training import models

## Step 1: Initialize train and validation dataloaders
trainset = COCOFormatDetectionDataset(data_dir="/content/soccer-players-2/",
                                      images_dir="train",
                                      json_annotation_file="train/_annotations.coco.json",
                                      input_dim=(640, 640),
                                      ignore_empty_annotations=False,
                                      transforms=[
                                          DetectionMosaic(prob=1., input_dim=(640, 640)),
                                          DetectionRandomAffine(degrees=0., scales=(0.5, 1.5), shear=0.,
                                                                target_size=(640, 640),
                                                                filter_box_candidates=False, border_value=128),
                                          DetectionHSV(prob=1., hgain=5, vgain=30, sgain=30),
                                          DetectionHorizontalFlip(prob=0.5),
                                          DetectionPaddedRescale(input_dim=(640, 640), max_targets=300),
                                          DetectionStandardize(max_value=255),
                                          DetectionTargetsFormatTransform(max_targets=300, input_dim=(640, 640),
                                                                          output_format="LABEL_CXCYWH")
                                      ])


valset = COCOFormatDetectionDataset(data_dir="/content/soccer-players-2/",
                                    images_dir="valid",
                                    json_annotation_file="valid/_annotations.coco.json",
                                    input_dim=(640, 640),
                                    ignore_empty_annotations=False,
                                    transforms=[
                                        DetectionPaddedRescale(input_dim=(640, 640), max_targets=300),
                                        DetectionStandardize(max_value=255),
                                        DetectionTargetsFormatTransform(max_targets=300, input_dim=(640, 640),
                                                                        output_format="LABEL_CXCYWH")
                                    ])

train_loader = dataloaders.get(dataset=trainset, dataloader_params={
    "shuffle": True,
    "batch_size": 16,
    "drop_last": False,
    "pin_memory": True,
    "collate_fn": DetectionCollateFN(),
    "worker_init_fn": worker_init_reset_seed,
    "min_samples": 512
})

valid_loader = dataloaders.get(dataset=valset, dataloader_params={
    "shuffle": False,
    "batch_size": 32,
    "num_workers": 2,
    "drop_last": False,
    "pin_memory": True,
    "collate_fn": DetectionCollateFN(),
    "worker_init_fn": worker_init_reset_seed
})

## Step 2: Defining training hyperparameters
train_params = {
    "warmup_initial_lr": 1e-6,
    "initial_lr": 5e-4,
    "lr_mode": "cosine",
    "cosine_final_lr_ratio": 0.1,
    "optimizer": "AdamW",
    "zero_weight_decay_on_bias_and_bn": True,
    "lr_warmup_epochs": 3,
    "warmup_mode": "linear_epoch_step",
    "optimizer_params": {"weight_decay": 0.0001},
    "ema": True,
    "ema_params": {"decay": 0.9, "decay_type": "threshold"},
    "max_epochs": 10,
    "mixed_precision": True,
    "loss": PPYoloELoss(use_static_assigner=False, num_classes=4, reg_max=16),
    "valid_metrics_list": [
        DetectionMetrics_050(score_thres=0.1, top_k_predictions=300, num_cls=4, normalize_targets=True,
                             post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01,
                                                                                    nms_top_k=1000, max_predictions=300,
                                                                                    nms_threshold=0.7))],

    "metric_to_watch": 'mAP@0.50'}
    
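# Checkpoints are saved under ckpt_root_dir/experiment_name (newer super-gradients versions may add a run-specific subfolder)
trainer = Trainer(experiment_name="yolo_nas_s_soccer_players", ckpt_root_dir="/content/sg_checkpoints_dir/")
# Start from COCO-pretrained YOLO-NAS S weights; num_classes must equal the value used in PPYoloELoss and DetectionMetrics_050
net = models.get(Models.YOLO_NAS_S, num_classes=4, pretrained_weights="coco")
trainer.train(model=net, training_params=train_params, train_loader=train_loader, valid_loader=valid_loader)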
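The COCO-format annotation file stores each box as [x_min, y_min, width, height], while DetectionTargetsFormatTransform above is configured to emit targets in LABEL_CXCYWH, i.e. a class label followed by the box center, width and height. The sketch below only illustrates that arithmetic on a single annotation, assuming pixel coordinates; it does not reproduce super-gradients' internal code, which also handles rescaling and padding. The helper names are hypothetical.

```python
from typing import List, Sequence


def coco_to_label_cxcywh(category_id: int, bbox: Sequence[float]) -> List[float]:
    """Convert one COCO box [x_min, y_min, w, h] into [label, cx, cy, w, h]."""
    x_min, y_min, w, h = bbox
    cx = x_min + w / 2.0  # horizontal center
    cy = y_min + h / 2.0  # vertical center
    return [float(category_id), cx, cy, float(w), float(h)]


def cxcywh_to_xyxy(cx: float, cy: float, w: float, h: float) -> List[float]:
    """Inverse helper: center format back to corner format [x1, y1, x2, y2]."""
    return [cx - w / 2.0, cy - h / 2.0, cx + w / 2.0, cy + h / 2.0]


if __name__ == "__main__":
    # Made-up annotation: class 2, box starting at (100, 50), 40 px wide, 80 px tall
    row = coco_to_label_cxcywh(2, [100, 50, 40, 80])
    print(row)                       # [2.0, 120.0, 90.0, 40.0, 80.0]
    print(cxcywh_to_xyxy(*row[1:]))  # [100.0, 50.0, 140.0, 130.0]
```

The xyxy form is what the inference results later expose as bboxes_xyxy, so the inverse helper is handy when comparing predictions with ground truth.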

 
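The learning-rate settings in train_params (warmup_initial_lr, initial_lr, lr_warmup_epochs, lr_mode="cosine", cosine_final_lr_ratio) describe a linear warmup followed by cosine decay. The function below is a simplified sketch of that schedule at epoch granularity, using the same parameter names; it assumes the cosine phase starts after warmup, and super-gradients may step per batch and measure cosine progress slightly differently.

```python
import math


def lr_at(
    epoch: float,
    initial_lr: float = 5e-4,
    warmup_initial_lr: float = 1e-6,
    lr_warmup_epochs: int = 3,
    max_epochs: int = 10,
    cosine_final_lr_ratio: float = 0.1,
) -> float:
    """Approximate learning rate at a (possibly fractional) epoch."""
    if epoch < lr_warmup_epochs:
        # Linear warmup: ramp from warmup_initial_lr up to initial_lr
        frac = epoch / lr_warmup_epochs
        return warmup_initial_lr + frac * (initial_lr - warmup_initial_lr)
    # Cosine decay from initial_lr down to initial_lr * cosine_final_lr_ratio
    final_lr = initial_lr * cosine_final_lr_ratio
    progress = (epoch - lr_warmup_epochs) / (max_epochs - lr_warmup_epochs)
    return final_lr + 0.5 * (initial_lr - final_lr) * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for e in range(11):
        print(f"epoch {e:2d}: lr = {lr_at(e):.2e}")
```

With these values the learning rate climbs to 5e-4 by epoch 3 and ends at 5e-5 (5e-4 × 0.1) at epoch 10.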

 
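The post-prediction callback used for validation is configured with score_threshold, nms_top_k, nms_threshold and max_predictions. To make those numbers concrete, here is a plain class-agnostic greedy non-maximum suppression (NMS) in NumPy. It is an illustrative sketch based on what the parameter names suggest, not super-gradients' implementation, which may, for example, suppress boxes per class.

```python
import numpy as np


def iou_xyxy(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """IoU between one [x1, y1, x2, y2] box and an (N, 4) array of boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def greedy_nms(boxes, scores, score_threshold=0.01, nms_top_k=1000,
               nms_threshold=0.7, max_predictions=300):
    """Return indices of kept boxes, highest score first."""
    candidates = np.where(scores >= score_threshold)[0]           # drop low-confidence boxes
    order = candidates[np.argsort(-scores[candidates])][:nms_top_k]  # top-k by score
    keep = []
    while order.size > 0 and len(keep) < max_predictions:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        # Suppress remaining boxes that overlap the kept box too much
        order = rest[iou_xyxy(boxes[best], boxes[rest]) < nms_threshold]
    return keep
```

A relatively high nms_threshold like 0.7 only removes boxes that overlap heavily, which keeps separate detections of players standing close together.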

While running this code in Colab, I ran into the following error:

AttributeError: module 'collections' has no attribute 'Iterable'

 

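The error occurs because Python 3.10, which Colab was running, removed the long-deprecated aliases such as collections.Iterable from the collections module; these abstract base classes are now available only from collections.abc (for example, collections.abc.Iterable).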
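A minimal standalone illustration of the change (the affected file is not named in the post, so this only demonstrates the pattern): the version-safe spelling is collections.abc.Iterable, and typing.Iterable, an older alias of it, also still works with isinstance() as long as it is not subscripted. The function names here are hypothetical.

```python
import collections
import collections.abc
import sys
import typing


def is_iterable(obj) -> bool:
    """Version-safe replacement for isinstance(obj, collections.Iterable)."""
    return isinstance(obj, collections.abc.Iterable)


def is_iterable_typing(obj) -> bool:
    """Same check through typing.Iterable (a deprecated alias, still valid here)."""
    return isinstance(obj, typing.Iterable)


if __name__ == "__main__":
    if sys.version_info >= (3, 10):
        # The old alias is gone, which is exactly what raises the AttributeError
        print("collections.Iterable exists:", hasattr(collections, "Iterable"))
    print(is_iterable([1, 2, 3]), is_iterable(42))             # True False
    print(is_iterable_typing("abc"), is_iterable_typing(3.5))  # True False
```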

 

There are two ways to fix it: downgrade to Python 3.9 or earlier, or edit the offending code directly.

I edited the Python file that raised the error.

 

In that file I deleted the collections. prefix from the Iterable reference. (Iterable from the typing module is already imported there, so the name then resolves to typing.Iterable.)

 

 

Then click the restart runtime button so the edited module is imported again.

 

After that, everything runs normally.

 

After training, I ran inference.

import os
from glob import glob

from super_gradients.common.object_names import Models
from super_gradients.training import Trainer, models

## Point to the directory that contains ckpt_best.pth
trainer = Trainer(experiment_name="yolo_nas_s_fake1", ckpt_root_dir="/yolo_nas/model/")
# num_classes must match the number of classes the checkpoint was trained with
net = models.get(Models.YOLO_NAS_S, num_classes=34,
                 checkpoint_path=os.path.join(trainer.checkpoints_dir_path, "ckpt_best.pth"))

## Test image folder
test_images = sorted(glob('/yolo_nas/volume/test/*.png'))

for img_path in test_images:
    prediction = net.predict(img_path)
    # predict() wraps results in a per-image list; take this image's result.
    # _images_prediction_lst is a private attribute and may change between versions.
    image_prediction = prediction._images_prediction_lst[0]
    bboxes = image_prediction.prediction.bboxes_xyxy   # (N, 4) boxes as x1, y1, x2, y2
    classes = image_prediction.prediction.labels       # (N,) class indices
    conf = image_prediction.prediction.confidence      # (N,) confidence scores
    print(img_path, bboxes, classes, conf)

This gives you the bboxes, classes, and conf values for each image from the inference results.
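Once you have those three arrays, a typical next step is to drop low-confidence detections and summarize what was found. The sketch below assumes bboxes is an (N, 4) array and classes and conf are (N,) arrays, as in the loop above; summarize_detections, the min_conf default and the class_names list are hypothetical and must be adapted to your own dataset.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple

import numpy as np


def summarize_detections(
    bboxes: np.ndarray,
    classes: np.ndarray,
    conf: np.ndarray,
    class_names: Sequence[str],
    min_conf: float = 0.5,
) -> Tuple[List[dict], Dict[str, int]]:
    """Keep detections with confidence >= min_conf and count them per class name."""
    keep = conf >= min_conf
    detections = [
        {"class": class_names[int(c)], "conf": float(s), "box_xyxy": [float(v) for v in b]}
        for b, c, s in zip(bboxes[keep], classes[keep], conf[keep])
    ]
    counts = dict(Counter(d["class"] for d in detections))
    return detections, counts


if __name__ == "__main__":
    # Made-up arrays and hypothetical class names, for illustration only
    b = np.array([[0, 0, 10, 10], [5, 5, 20, 20], [1, 1, 2, 2]], dtype=float)
    c = np.array([0, 1, 1])
    s = np.array([0.9, 0.6, 0.2])
    dets, counts = summarize_detections(b, c, s, ["player", "ball"])
    print(counts)  # {'player': 1, 'ball': 1}
```

Inside the inference loop you would call it as summarize_detections(bboxes, classes, conf, class_names) for each image.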