Predict with a Pre-Trained Model

We can use a saved model (like the one we trained on the previous page) either to continue training or to make predictions. This page of the MXNet tutorial explains how to make predictions for new examples with a pre-trained model.

We need several modules:

use AI::MXNet qw(mx);
use AI::MXNet::Gluon qw(gluon);
use AI::MXNet::Gluon::NN qw(nn);
use AI::MXNet::AutoGrad qw(autograd);
use AI::MXNet::Base;

And a future version of this page will require:

use PDL::Graphics::Gnuplot;

so that we humans can see the images that our computer is seeing. In the meantime, if PDL::Graphics::Gnuplot is installed, a quick way to examine a single image from the PDL shell is:

pdl> $example = ${$mnist_valid->data}[$i]->reshape([1,28,28]);
pdl> gplot( with=>'image', $example->aspdl->slice(':,27:0,0'));

where slice(':,27:0,0') flips the image vertically to its correct orientation.
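To see what the `27:0` range does, here is a toy illustration in plain Perl, without PDL. It assumes a hypothetical 3×3 "image" stored as an array of rows (top row first); reversing the row order is the same operation that `slice(':,27:0,0')` performs on the 28 rows of a Fashion-MNIST image.

```perl
use strict;
use warnings;

# A hypothetical 3x3 grayscale image, stored row by row (top row first).
my @image = (
    [ 1, 2, 3 ],
    [ 4, 5, 6 ],
    [ 7, 8, 9 ],
);

# Reverse the order of the rows, as the PDL range 27:0 does for 28 rows.
# Each pixel keeps its column; only the vertical order changes.
sub flip_vertical {
    my (@rows) = @_;
    return reverse @rows;
}

my @flipped = flip_vertical(@image);
print join(' ', @$_), "\n" for @flipped;
```

The bottom row is now printed first, which is why the flip is needed when the plot's vertical axis points upward.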

prerequisites

The example above assumes that you have loaded the validation data as $mnist_valid, as we did on the previous page, so let's start by bringing in what we need from that page. First, we need the validation data:

my $mnist_valid = gluon->data->vision->FashionMNIST(
    root=>'./data/fashion-mnist', train=>0, transform=>\&transformer);

my $valid_data = gluon->data->DataLoader(
    $mnist_valid, batch_size=>1, shuffle=>0);

and we need the transformer subroutine, which must apply the same preprocessing that we used when training the model:

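sub transformer {
    my ($data, $label) = @_;
    $data = $data->reshape([1,28,28]);        # put the single channel first: 1 x 28 x 28
    $data = $data->astype('float32') / 255;   # scale pixel values to 0..1
    $data = ($data - 0.31) / 0.31;            # normalize exactly as during training
    return ($data, $label);
}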
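The arithmetic in `transformer` maps each 8-bit pixel value p to (p/255 − 0.31)/0.31. A plain-Perl sketch of that mapping for a few sample pixel values (the helper name `normalize_pixel` is ours, not part of AI::MXNet):

```perl
use strict;
use warnings;

# Same per-pixel arithmetic as the transformer subroutine:
# scale 0..255 to 0..1, then subtract 0.31 and divide by 0.31.
sub normalize_pixel {
    my ($p) = @_;
    return ( $p / 255 - 0.31 ) / 0.31;
}

for my $p (0, 128, 255) {
    printf "%3d -> %8.4f\n", $p, normalize_pixel($p);
}
```

So a black pixel becomes −1 and a white pixel becomes roughly 2.23.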

We also need the network:

my $net = nn->Sequential();

$net->name_scope(sub {
    $net->add(nn->Conv2D(channels=>6, kernel_size=>5, activation=>'relu'));
    $net->add(nn->MaxPool2D(pool_size=>2, strides=>2));
    $net->add(nn->Conv2D(channels=>16, kernel_size=>3, activation=>'relu'));
    $net->add(nn->MaxPool2D(pool_size=>2, strides=>2));
    $net->add(nn->Flatten());
    $net->add(nn->Dense(120, activation=>"relu"));
    $net->add(nn->Dense(84, activation=>"relu"));
    $net->add(nn->Dense(10));
});
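It can help to trace the image size through this network by hand. The sketch below does so in plain Perl, assuming the Gluon defaults of stride 1 and no padding for `Conv2D` and no padding for `MaxPool2D`; the helper `out_size` is hypothetical and only applies the usual output-size formula floor((n − k)/s) + 1.

```perl
use strict;
use warnings;
use POSIX qw(floor);

# Output size of a convolution or pooling window along one spatial axis,
# assuming no padding: floor((n - k) / s) + 1.
sub out_size {
    my ($n, $k, $s) = @_;
    return floor( ($n - $k) / $s ) + 1;
}

my $side = 28;                     # Fashion-MNIST images are 28x28
$side = out_size($side, 5, 1);     # Conv2D, kernel 5          -> 24
$side = out_size($side, 2, 2);     # MaxPool2D, pool 2 step 2  -> 12
$side = out_size($side, 3, 1);     # Conv2D, kernel 3          -> 10
$side = out_size($side, 2, 2);     # MaxPool2D, pool 2 step 2  -> 5

my $flat = 16 * $side * $side;     # 16 channels after the second Conv2D
print "side = $side, flattened = $flat\n";   # side = 5, flattened = 400
```

The 400 flattened values then feed the Dense layers of 120, 84 and 10 units, the last of which produces one score per clothing class.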

And we need the parameters of the model that we trained:

my $param_file = './data/params/fashion-mnist.params';

$net->load_parameters($param_file);

predict

Our next step is to compare the model's predictions with the correct labels. The labels are class indices from 0 to 9, so let's give them names:

my @text_labels = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot');
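The network outputs one score per class, `argmax` picks the index of the largest score, and that index selects a name from `@text_labels`. A plain-Perl sketch of that lookup, using a made-up score vector (not real model output):

```perl
use strict;
use warnings;

my @text_labels = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot');

# Return the index of the largest element (the first one wins on ties).
sub argmax {
    my @v = @_;
    my $best = 0;
    for my $i (1 .. $#v) {
        $best = $i if $v[$i] > $v[$best];
    }
    return $best;
}

# Hypothetical scores for one image; index 8 has the largest score.
my @scores = (-1.2, 0.3, -0.5, 0.1, -2.0, 1.1, 0.4, -0.7, 3.6, 0.9);
my $idx = argmax(@scores);
print "$idx => $text_labels[$idx]\n";    # 8 => bag
```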

Now, let's take the first ten images and labels from the validation set:

my @tendl = @{$valid_data}[0..9];

and compare:

my $topline;

$topline .= ' PREDICTION :: CORRECT'."\n";

$topline .= ' ========== :: ======='."\n";

print $topline;


for my $i (0..$#tendl) {
    my $data  = ${$tendl[$i]}[0];
    my $label = ${$tendl[$i]}[1];

    # the index of the highest score is the predicted class
    my $ot   = $net->($data)->argmax({axis=>1});
    my $pred = $text_labels[ PDL::sclr( $ot->aspdl ) ];
    my $true = $text_labels[ PDL::sclr( $label->aspdl ) ];

    # right-align the prediction, left-align the correct label, flag mistakes with XX
    my $otline;
    $otline .= sprintf("%12s", $pred) . " :: ";
    $otline .= sprintf("%-10s", $true) . " ";
    $otline .= ( $pred eq $true ) ? ".." : "XX";

    print $otline . "\n";
}

which prints:


PREDICTION :: CORRECT
========== :: =======
   t-shirt :: t-shirt    ..
   trouser :: trouser    ..
  pullover :: pullover   ..
  pullover :: pullover   ..
     dress :: dress      ..
     shirt :: pullover   XX
       bag :: bag        ..
     shirt :: shirt      ..
    sandal :: sandal     ..
   t-shirt :: t-shirt    ..

The model correctly predicted nine out of ten.
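For more than ten images, counting matches by eye gets tedious. A small plain-Perl sketch that tallies accuracy from (prediction, correct label) pairs; here the pairs are copied from the output above, but in practice you would collect `$pred` and `$true` inside the loop.

```perl
use strict;
use warnings;

# (prediction, correct label) pairs copied from the table above.
my @pairs = (
    ['t-shirt',  't-shirt'],  ['trouser', 'trouser'],
    ['pullover', 'pullover'], ['pullover', 'pullover'],
    ['dress',    'dress'],    ['shirt',   'pullover'],
    ['bag',      'bag'],      ['shirt',   'shirt'],
    ['sandal',   'sandal'],   ['t-shirt', 't-shirt'],
);

# grep in scalar context counts the matching pairs.
my $correct  = grep { $_->[0] eq $_->[1] } @pairs;
my $accuracy = $correct / @pairs;
printf "%d/%d correct (%.0f%%)\n", $correct, scalar @pairs, 100 * $accuracy;
```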

Gluon model zoo

In practice, we do not always need to train large-scale models from scratch: the Gluon model zoo provides many pre-trained models. Likewise, the Perl script for this section was not written from scratch; it is adapted from examples provided by Sergey Kolychev, the author of the AI::MXNet Perl module.

So let's use the ModelZoo:

use AI::MXNet::Gluon::ModelZoo 'get_model';

use AI::MXNet::Gluon::Utils 'download';

And let's work with the ResNet-152 V2 model, which was trained on the ImageNet dataset:

my $model = 'resnet152_v2';

my $net = get_model($model, pretrained=>1);

We also need the text labels for each class:

my $fname = download('http://data.mxnet.io/models/imagenet/synset.txt');

use IO::File;
# Each line is "<WordNet ID> <names>": drop the newline and the leading ID.
my @text_labels = map { chomp; s/^\S+\s+//; $_ } IO::File->new($fname)->getlines;
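Each line of `synset.txt` begins with an ID followed by the human-readable class names, and the substitution `s/^\S+\s+//` strips the ID. A standalone sketch of that parsing on two lines written inline in the same format (treat them as illustrative rather than as a copy of the downloaded file):

```perl
use strict;
use warnings;

# Two illustrative lines in the synset.txt format: "<id> <names>\n".
my @lines = (
    "n01440764 tench, Tinca tinca\n",
    "n01443537 goldfish, Carassius auratus\n",
);

# Same transformation as above: drop the newline, then the leading ID.
my @labels = map { chomp; s/^\S+\s+//; $_ } @lines;
print "$_\n" for @labels;
```

Note that `map` aliases `$_` to each element, so `chomp` and `s///` also modify `@lines` in place; that is harmless here, as it is for the temporary list returned by `getlines`.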

And we need the image:

my $image = 'kyuubi.jpg';

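To prepare the image, we first resize it so that its shorter edge is 256 pixels, and then we crop a 224×224 square from its center. (Inception models expect larger inputs, so for them the code uses 330 and 299 pixels instead.)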

$image = mx->image->imread($image);
$image = mx->image->resize_short($image, $model =~ /inception/ ? 330 : 256);
# center_crop returns the cropped image and the crop rectangle; keep the image
($image) = mx->image->center_crop($image, [($model =~ /inception/ ? 299 : 224)x2]);
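The resizing and cropping arithmetic can be worked out by hand. The sketch below assumes that `resize_short` rounds the longer edge down to an integer, as MXNet's Python implementation does, and that `center_crop` leaves equal margins on each side; the helper names are hypothetical and the 500×375 input size is only an example.

```perl
use strict;
use warnings;

# New (width, height) after scaling so that the shorter edge equals $size.
# Rounding the longer edge down is an assumption about MXNet's behavior.
sub resize_short_dims {
    my ($w, $h, $size) = @_;
    return $h > $w
        ? ($size, int($size * $h / $w))
        : (int($size * $w / $h), $size);
}

# Top-left corner of a centered crop of $cw x $ch inside a $w x $h image.
sub center_crop_origin {
    my ($w, $h, $cw, $ch) = @_;
    return (int(($w - $cw) / 2), int(($h - $ch) / 2));
}

my ($w, $h)   = resize_short_dims(500, 375, 256);      # 341 x 256
my ($x0, $y0) = center_crop_origin($w, $h, 224, 224);  # (58, 16)
print "resized to ${w}x$h, crop starts at ($x0, $y0)\n";
```

So a 500×375 photo would end up as the central 224×224 region of a 341×256 resized image.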

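The decoded image is laid out as height × width × channel (HWC), but the network expects a batch of images laid out as batch × channel × height × width (NCHW). So we move the channel axis to the front and add a batch axis of size 1:

$image = $image->transpose([2,0,1])->expand_dims(axis=>0);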
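What `transpose([2,0,1])` and `expand_dims(axis=>0)` do to the shape can be checked with a short plain-Perl sketch, assuming a 224×224 image stored as height × width × channel: output axis i takes input axis axes[i], and the new batch axis of size 1 goes in front. The helper `permute_shape` is hypothetical; it tracks shapes only, not pixel data.

```perl
use strict;
use warnings;

# Shape after a transpose: output axis $i is input axis $axes->[$i].
sub permute_shape {
    my ($shape, $axes) = @_;
    return [ map { $shape->[$_] } @$axes ];
}

my $hwc  = [224, 224, 3];                    # height, width, channel
my $chw  = permute_shape($hwc, [2, 0, 1]);   # channel, height, width
my $nchw = [1, @$chw];                       # expand_dims(axis=>0)
print join('x', @$nchw), "\n";               # 1x3x224x224
```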

Then we normalize each color channel: we scale the pixel values to the range 0 to 1, subtract the channel means, and divide by the channel standard deviations (the standard ImageNet statistics):

my $rgb_mean = mx->nd->array([0.485, 0.456, 0.406])->reshape([1,3,1,1]);
my $rgb_std = mx->nd->array([0.229, 0.224, 0.225])->reshape([1,3,1,1]);
$image = ($image->astype('float32') / 255 - $rgb_mean) / $rgb_std;
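The per-channel normalization is easy to check on a single pixel. A plain-Perl sketch using the same means and standard deviations, applied to an arbitrary RGB pixel (255, 128, 0); the helper `normalize_rgb` is ours, for illustration only.

```perl
use strict;
use warnings;

my @rgb_mean = (0.485, 0.456, 0.406);
my @rgb_std  = (0.229, 0.224, 0.225);

# Scale each 8-bit channel to 0..1, subtract its mean, divide by its std.
sub normalize_rgb {
    my @pixel = @_;
    return map { my $c = $_; ($pixel[$c] / 255 - $rgb_mean[$c]) / $rgb_std[$c] } 0 .. 2;
}

my @out = normalize_rgb(255, 128, 0);
printf "%.4f %.4f %.4f\n", @out;
```

In the NDArray version, the `[1,3,1,1]` shape of `$rgb_mean` and `$rgb_std` lets the same three numbers be broadcast over every pixel of the image.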

Finally, we try to recognize the object in the image. The network outputs raw scores, so to obtain probabilities we apply a softmax to the output, and then we print the five most likely classes.

my $prob = $net->($image)->softmax;

for my $idx (@{ $prob->topk(k=>5)->at(0) }) {
    my $i = $idx->asscalar;
    printf(
        "With prob = %.5f, it contains %s\n",
        $prob->at(0)->at($i)->asscalar, $text_labels[$i]
    );
}
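`softmax` turns the raw scores s into probabilities exp(s_i)/Σ_j exp(s_j), and `topk` returns the indices of the largest values. A plain-Perl sketch of both steps on hypothetical scores for six classes; subtracting the maximum before exponentiating is a standard trick for numerical stability and does not change the result.

```perl
use strict;
use warnings;
use List::Util qw(max sum);

# Softmax: exponentiate (after subtracting the max, for stability) and normalize.
sub softmax {
    my @s = @_;
    my $m = max(@s);
    my @e = map { exp($_ - $m) } @s;
    my $z = sum(@e);
    return map { $_ / $z } @e;
}

# Indices of the $k largest values, largest first.
sub topk {
    my ($k, @v) = @_;
    my @order = sort { $v[$b] <=> $v[$a] } 0 .. $#v;
    return @order[0 .. $k - 1];
}

my @scores = (2.0, 0.5, 3.1, -1.0, 1.7, 0.0);    # hypothetical raw scores
my @prob   = softmax(@scores);
for my $i (topk(3, @prob)) {
    printf "With prob = %.5f, it contains class %d\n", $prob[$i], $i;
}
```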

Copyright © 2002-2024 Eryk Wdowiak