What are the specifics of 'torch.save()'?

Code Structure in train.py
model.train()
Model Training Code
model.eval()
Training accuracy and prediction
torch.save()

I use train.py to train and save my model. Later, I continue training and make predictions with the saved model. Does torch.save() preserve the train() or eval() state of the model? I’m new to this, so any clarification would be helpful.


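It's not critical either way. If you save the whole model with torch.save(model, ...), the module is pickled together with its training flag, so the train()/eval() state comes back when you load it. If you save only model.state_dict(), you get the parameters and buffers but not the mode. In both cases, after loading you can simply call model.train() or model.eval() to put the model in the state you want.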

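import torch

# Save the whole model (pickles the module object, including its training flag)
torch.save(model, 'model.pth')

# Load the model. Recent PyTorch releases default to weights_only=True,
# so unpickling a full model needs weights_only=False (only for files you trust).
# The model's class definition must also be importable at load time.
model = torch.load('model.pth', weights_only=False)

# Set the model mode for predictions
model.eval()

# Or set it to training mode to continue training
model.train()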
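To check the whole-model behaviour yourself, here is a minimal sketch. The tiny nn.Sequential model and the in-memory buffer are stand-ins chosen for illustration, not the model from the question; the mode lives in the module's `training` attribute, which train() and eval() toggle.

```python
import io

import torch
import torch.nn as nn

# Hypothetical toy model, used only to illustrate the behaviour.
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
model.eval()  # sets model.training = False on the module and its children

# Save the whole module to an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(model, buffer)
buffer.seek(0)

# weights_only=False is needed to unpickle a full module on recent releases.
restored = torch.load(buffer, weights_only=False)

# The training flag was pickled with the module, so eval mode is restored.
restored_mode = restored.training
print(restored_mode)
```

Even though the flag survives here, calling model.eval() or model.train() explicitly after loading keeps the script from depending on whatever mode the model was in when it was saved.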

Alternatively, saving and loading the model’s state dictionary (model.state_dict()) provides more flexibility:

import torch

# Save the model state dictionary
torch.save(model.state_dict(), 'model_state.pth')

# Load the model state dictionary
model.load_state_dict(torch.load('model_state.pth'))

# Set the model mode for predictions
model.eval()

# Or set it to training mode to continue training
model.train()


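In general, this approach is the recommended one. Keep in mind that the state dictionary stores only the model's parameters and buffers, not its train/eval mode, so you need to build the model first and then set the mode explicitly after loading.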
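A small sketch of the state dictionary round trip, assuming a hypothetical make_model() helper and a toy architecture that are not part of the original train.py. It shows that the state dictionary carries no mode information and that a freshly built model stays in training mode until you change it.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture; the same definition is needed at save and load time.
    return nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))


trained = make_model()
trained.eval()

# The state dict maps names to tensors (weights, biases, BatchNorm statistics).
state = trained.state_dict()
has_mode_entry = any("training" in key for key in state)

# Round trip through an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

fresh = make_model()  # new modules start in training mode
fresh.load_state_dict(torch.load(buffer))
mode_after_load = fresh.training  # unchanged by load_state_dict

fresh.eval()  # set the mode explicitly before making predictions
print(has_mode_entry, mode_after_load, fresh.training)
```

Because the mode is not part of the file, the script that loads the weights is the right place to decide between eval() for predictions and train() for further training.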