How to extract loss and accuracy from logger by each epoch in pytorch lightning?

I want to extract all of the logged data so I can make the plots myself, without using TensorBoard. My understanding is that all the logged loss and accuracy values are stored in a defined directory, since TensorBoard draws its line graphs from them.

%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

However, I don't know how to extract all of those logs from the logger in PyTorch Lightning. Here is the training part of my code:

#model
ssl_classifier = SSLImageClassifier(lr=lr)

#train
logger = pl.loggers.TensorBoardLogger(name=f'ssl-{lr}-{num_epoch}', save_dir='lightning_logs')

trainer = pl.Trainer(progress_bar_refresh_rate=20,
                     gpus=1,
                     max_epochs=max_epoch,
                     logger=logger)

trainer.fit(ssl_classifier, train_loader, val_loader)

I confirmed that trainer.logger.log_dir returns a directory that seems to hold the logs, and that trainer.logger.log_metrics returns <bound method TensorBoardLogger.log_metrics of <pytorch_lightning.loggers.tensorboard.TensorBoardLogger object at 0x7efcb89a3e50>>.

trainer.logged_metrics returns only the values from the final epoch, for example:

{'epoch': 19,
 'train_acc': tensor(1.),
 'train_loss': tensor(0.1038),
 'val_acc': 0.6499999761581421,
 'val_loss': 1.2171183824539185}

How can I get these metrics for every epoch?

Drumm answered 22/9, 2021 at 0:44

Lightning does not store all logs by itself. All it does is stream them to the logger instance, and the logger decides what to do with them.
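With the TensorBoardLogger used in the question, the values streamed to the logger end up in event files under trainer.logger.log_dir, so they can also be read back without opening TensorBoard. Below is a minimal sketch, assuming the tensorboard package's EventAccumulator (the parser TensorBoard itself uses) is available; the function name read_tensorboard_scalars is hypothetical. The recorded steps are global training steps, not epochs; the epoch number is usually logged as its own "epoch" scalar, which can be matched up by step.

```python
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def read_tensorboard_scalars(log_dir):
    """Return {tag: [(step, value), ...]} for every scalar found in log_dir."""
    # size_guidance of 0 asks the accumulator to keep every scalar event
    # instead of the default reservoir sample.
    acc = EventAccumulator(log_dir, size_guidance={"scalars": 0})
    acc.Reload()  # parse the events.out.tfevents.* files on disk
    history = {}
    for tag in acc.Tags()["scalars"]:
        history[tag] = [(event.step, event.value) for event in acc.Scalars(tag)]
    return history
```

For example, read_tensorboard_scalars(trainer.logger.log_dir)["val_loss"] would give the (step, value) pairs for the validation loss after training.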

The best way to retrieve all logged metrics is with a custom callback:

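from pytorch_lightning.callbacks import Callback

class MetricTracker(Callback):

  def __init__(self):
    self.collection = []

  def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
    # outputs is whatever validation_step returned; this assumes a dict with a 'val_acc' key
    vacc = outputs['val_acc']  # you can access them here
    self.collection.append(vacc)  # track them

  def on_validation_epoch_end(self, trainer, pl_module):
    if trainer.sanity_checking:  # skip the sanity-check validation run before training
      return
    # copy, because Lightning may keep updating the same dict object in place
    elogs = dict(trainer.logged_metrics)
    self.collection.append(elogs)
    # do whatever is needed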

You can then access everything the callback collected through the callback instance:

cb = MetricTracker()
trainer = pl.Trainer(callbacks=[cb])
trainer.fit(ssl_classifier, train_loader, val_loader)

cb.collection  # do your plotting and stuff
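cb.collection mixes per-batch values (appended in on_validation_batch_end) with per-epoch dicts (appended in on_validation_epoch_end). One way to turn it into plottable series is sketched below; the helper names epoch_series and _to_number are hypothetical, and the sketch assumes each per-epoch dict carries an "epoch" key, as the trainer.logged_metrics output in the question does.

```python
def _to_number(value):
    # 0-d torch tensors and NumPy scalars expose .item(); plain numbers pass through
    return value.item() if hasattr(value, "item") else value


def epoch_series(collection):
    """Group per-epoch metric dicts into {metric_name: [(epoch, value), ...]}."""
    series = {}
    for entry in collection:
        if not isinstance(entry, dict):
            continue  # skip per-batch values such as the appended val_acc
        epoch = _to_number(entry.get("epoch"))
        for name, value in entry.items():
            if name == "epoch":
                continue
            series.setdefault(name, []).append((epoch, _to_number(value)))
    return series
```

Keeping (epoch, value) pairs means a metric that is not logged every epoch still lines up with the right epoch. With matplotlib, for example: epochs, losses = zip(*epoch_series(cb.collection)["val_loss"]); plt.plot(epochs, losses).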
Rockel answered 22/9, 2021 at 23:47

The accepted answer is not fundamentally wrong, but it does not follow the current official PyTorch Lightning guidelines.

As suggested in the documentation (https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#make-a-custom-logger), one can write a custom logger class like this:

from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment


class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "MyLogger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        pass

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        # If you implement this, remember to call `super().save()`
        # at the start of the method (important for aggregation of metrics)
        super().save()

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

By looking inside the LightningLoggerBase class, one can see which other methods could be overridden.

Here is a minimalistic logger of mine. It is far from optimised, but it makes a good first attempt. I will edit it if I improve it.

import collections

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


class History_dict(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        # defaultdict creates an empty list the first time a new metric name is accessed
        self.history = collections.defaultdict(list)

    @property
    def name(self):
        return "Logger_custom_plot"

    @property
    def version(self):
        return "1.0"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger (none needed here).
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        for metric_name, metric_value in metrics.items():
            if metric_name != 'epoch':
                self.history[metric_name].append(metric_value)
            else:
                # 'epoch' arrives with every logging call (e.g. once per loss),
                # so only record it when it differs from the last stored epoch.
                if (not len(self.history['epoch']) or
                        self.history['epoch'][-1] != metric_value):
                    self.history['epoch'].append(metric_value)

    @rank_zero_only
    def log_hyperparams(self, params):
        pass
Acro answered 3/12, 2021 at 19:36

© 2022 - 2024 — McMap. All rights reserved.