The model weights seem broken

#1 opened by Dianyo

Hi,
I downloaded this model and ran it on the ImageNet-1K validation set, but the accuracy is very poor (about 25% top-1). Here's the script I used for validation:

import logging
import time

import torch
from timm.utils import AverageMeter, accuracy  # assumed source of these helpers

logger = logging.getLogger(__name__)


@torch.no_grad()  # inference only: no autograd graph needed
def validate(config, data_loader, model, is_timm_model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        start = time.time()  # time the forward pass only, not data loading
        with torch.autocast(device_type="cuda", enabled=config.AMP_ENABLE):
            if is_timm_model:
                # timm models return the logits tensor directly
                output = model(images)
            else:
                # Hugging Face transformers models return an output object
                output = model(images, return_dict=True).logits

        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # .item() synchronizes with the GPU, so the timing below is complete
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))
        batch_time.update(time.time() - start)

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
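The script calls an `accuracy` helper that isn't shown; `timm.utils.accuracy` has this call signature and returns percentages. One way to rule out the metric as the cause is to compare it against a stand-alone top-k computation. The sketch below is a minimal re-implementation (the name `topk_accuracy` is hypothetical, not timm's code): a sample counts as correct for a given k if its label appears among the k highest logits.

```python
import torch


def topk_accuracy(logits, target, topk=(1,)):
    """Hypothetical helper: for each k in topk, the % of samples whose label is in the top-k logits."""
    maxk = min(max(topk), logits.size(1))
    batch_size = target.size(0)
    # indices of the maxk highest-scoring classes per sample: (batch, maxk)
    _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
    # True where one of the top predictions matches the label: (batch, maxk)
    correct = pred.eq(target.view(-1, 1))
    return [
        correct[:, :min(k, maxk)].any(dim=1).float().sum() * 100.0 / batch_size
        for k in topk
    ]


logits = torch.tensor([
    [0.1, 0.7, 0.2],    # label 1 is the top-1 prediction
    [0.5, 0.3, 0.2],    # label 1 is ranked second
    [0.2, 0.3, 0.5],    # label 0 is ranked last
    [0.9, 0.05, 0.05],  # label 0 is the top-1 prediction
])
target = torch.tensor([1, 1, 0, 0])
acc1, acc2 = topk_accuracy(logits, target, topk=(1, 2))
print(acc1.item(), acc2.item())  # 50.0 75.0
```

If this agrees with the helper on a real batch, the low numbers come from the inputs or the model, not from the metric.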

Here's the truncated output:

......
[2024-11-04 07:12:04 tf_efficientnetv2_s.in1k](main_baseline.py 62): INFO Test: [340/391]       Time 0.039 (0.042)      Loss 4.8750 (4.8620)  Acc@1 23.438 (24.931)   Acc@5 40.625 (40.394)   Mem 683MB
[2024-11-04 07:12:05 tf_efficientnetv2_s.in1k](main_baseline.py 62): INFO Test: [350/391]       Time 0.039 (0.042)      Loss 5.0469 (4.8623)  Acc@1 24.219 (24.922)   Acc@5 41.406 (40.389)   Mem 683MB
[2024-11-04 07:12:06 tf_efficientnetv2_s.in1k](main_baseline.py 62): INFO Test: [360/391]       Time 0.039 (0.042)      Loss 4.9336 (4.8643)  Acc@1 30.469 (24.961)   Acc@5 39.062 (40.357)   Mem 683MB
[2024-11-04 07:12:07 tf_efficientnetv2_s.in1k](main_baseline.py 62): INFO Test: [370/391]       Time 0.039 (0.042)      Loss 4.4883 (4.8630)  Acc@1 27.344 (24.979)   Acc@5 45.312 (40.347)   Mem 683MB
[2024-11-04 07:12:08 tf_efficientnetv2_s.in1k](main_baseline.py 62): INFO Test: [380/391]       Time 0.039 (0.042)      Loss 5.0078 (4.8639)  Acc@1 25.000 (24.984)   Acc@5 37.500 (40.332)   Mem 683MB
[2024-11-04 07:12:09 tf_efficientnetv2_s.in1k](main_baseline.py 62): INFO Test: [390/391]       Time 0.063 (0.042)      Loss 5.2383 (4.8631)  Acc@1 22.500 (24.984)   Acc@5 36.250 (40.366)   Mem 683MB
[2024-11-04 07:12:10 tf_efficientnetv2_s.in1k](main_baseline.py 70): INFO  * Acc@1 24.984 Acc@5 40.366
PyTorch Image Models org

@Dianyo The weights are fine. The timm validate.py script is the reference (https://github.com/huggingface/pytorch-image-models/blob/main/validate.py); you probably used incorrect preprocessing.
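In timm, eval preprocessing is derived from each model's pretrained config (input size, `crop_pct`, interpolation, mean/std). `timm.data.resolve_data_config({}, model=model)` followed by `timm.data.create_transform(**config)` builds the matching transform, and as far as I know validate.py resolves the model's test-time input size by default. The sketch below re-implements only the core resize, center-crop, normalize steps in plain PyTorch to show how those parameters interact. The function `eval_preprocess` is hypothetical, and the values in the usage lines (bicubic, `crop_pct=1.0`, mean/std of 0.5, 384 px test size) are my understanding of the `tf_efficientnetv2_s.in1k` config; check `model.pretrained_cfg` before relying on them. If they are right, a typical 224 px / `crop_pct=0.875` / ImageNet-mean pipeline would mismatch on size, crop and normalization at once.

```python
import math

import torch
import torch.nn.functional as F


def eval_preprocess(img, img_size, crop_pct, mean, std, interpolation="bicubic"):
    """Hypothetical timm-style eval transform for one image.

    img: float tensor of shape (3, H, W) with values in [0, 1].
    Resizes the shorter side to floor(img_size / crop_pct), center-crops
    img_size x img_size, then normalizes each channel with mean/std.
    """
    scale_size = math.floor(img_size / crop_pct)
    _, h, w = img.shape
    # keep the aspect ratio: the shorter side becomes scale_size
    if h <= w:
        new_h, new_w = scale_size, int(w * scale_size / h)
    else:
        new_h, new_w = int(h * scale_size / w), scale_size
    img = F.interpolate(
        img.unsqueeze(0), size=(new_h, new_w), mode=interpolation,
        align_corners=False, antialias=True,
    ).squeeze(0)
    # bicubic resampling can overshoot the input range slightly
    img = img.clamp(0.0, 1.0)

    # center crop to img_size x img_size
    top = (new_h - img_size) // 2
    left = (new_w - img_size) // 2
    img = img[:, top:top + img_size, left:left + img_size]

    mean_t = torch.tensor(mean, dtype=img.dtype).view(3, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype).view(3, 1, 1)
    return (img - mean_t) / std_t


# ASSUMED tf_efficientnetv2_s.in1k eval settings; verify via model.pretrained_cfg
image = torch.rand(3, 375, 500)
x = eval_preprocess(image, img_size=384, crop_pct=1.0,
                    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
print(x.shape)  # torch.Size([3, 384, 384])
```

With `crop_pct=1.0` there is no border crop, whereas `crop_pct=0.875` at 224 px resizes to 256 px first and discards the outer ring, so the same photo reaches the network at a different scale.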

rwightman changed discussion status to closed
