Having previously run pre-training for a Tabular BERT model with SageMaker Training Jobs, I created a fine-tuning job that uses the model.tar.gz file produced by pre-training.
Although the job is fundamentally similar to the pre-training job, a few adjustments were needed, which I have compiled as a memo:
- Extracting the tar file
- Switching arguments between local environment and SageMaker using environment variables
Extracting the tar file
The pre-training job saves a model.tar.gz file to S3. This archive contains the model pytorch_model.bin, the configuration file config.json, the dictionary file vocab.nb, and the token-to-id mapping file vocab_token2id.bin. To load these files during fine-tuning, the tar file has to be extracted when the job runs.
First, set the S3 path of the model.tar.gz file as the input_model channel when creating the job. As a result, model.tar.gz is placed in the /opt/ml/input/data/input_model/ (model_path) directory when the job is executed.
import sagemaker
from sagemaker.estimator import Estimator

session = sagemaker.Session()
role = sagemaker.get_execution_role()

estimator = Estimator(
    image_uri=<image-url>,
    role=role,
    instance_type="ml.g4dn.2xlarge",
    instance_count=1,
    base_job_name="tabformer-opt-fine-tuning",
    output_path="s3://<bucket-name>/sagemaker/output_data/fine_tuning",
    code_location="s3://<bucket-name>/sagemaker/output_data/fine_tuning",
    sagemaker_session=session,
    entry_point="fine-tuning.sh",
    dependencies=["tabformer-opt"],
    hyperparameters={
        "data_root": "/opt/ml/input/data/input_data/",
        "data_fname": "summary",
        "output_dir": "/opt/ml/model/",
        "model_path": "/opt/ml/input/data/input_model/",
    },
)

estimator.fit({
    "input_data": "s3://<bucket-name>/sagemaker/input_data/summary.csv",
    "input_model": "s3://<bucket-name>/sagemaker/output_data/pre_training/tabformer-opt-2022-12-16-07-00-45-931/output/model.tar.gz",
})
Next, include the following in the fine-tuning execution file tabformer_bert_fine_tuning.py:
import tarfile
from os import path

with tarfile.open(name=path.join(args.model_path, "model.tar.gz"), mode="r:gz") as mytar:
    mytar.extractall(path.join(args.model_path, "model"))

token2id_file = path.join(args.model_path, "model/vocab_token2id.bin")
vocab_file = path.join(args.model_path, "model/vocab.nb")
pretrained_model = path.join(args.model_path, "model/checkpoint-500/pytorch_model.bin")
pretrained_config = path.join(args.model_path, "model/checkpoint-500/config.json")
The tarfile.open() function reads the model.tar.gz file, and mytar.extractall() extracts its contents under the /opt/ml/input/data/input_model/model/ directory. The extracted files can then be loaded, for example with token2id_file = path.join(args.model_path, "model/vocab_token2id.bin").
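This extraction pattern can be checked locally without SageMaker. The sketch below builds a dummy model.tar.gz in a temporary directory standing in for /opt/ml/input/data/input_model/ (the dummy file contents are my own; only the file names come from the article) and extracts it the same way:

```python
import tarfile
import tempfile
from os import path

# Stand-ins for the SageMaker input channel and the archive's source files.
model_path = tempfile.mkdtemp()  # plays the role of /opt/ml/input/data/input_model/
src = tempfile.mkdtemp()
for name in ["vocab.nb", "vocab_token2id.bin"]:
    with open(path.join(src, name), "wb") as f:
        f.write(b"dummy")  # placeholder content, not a real vocabulary

# Build a dummy model.tar.gz, as the pre-training job would have.
with tarfile.open(path.join(model_path, "model.tar.gz"), "w:gz") as tar:
    for name in ["vocab.nb", "vocab_token2id.bin"]:
        tar.add(path.join(src, name), arcname=name)

# Extract it exactly as the fine-tuning script does.
with tarfile.open(name=path.join(model_path, "model.tar.gz"), mode="r:gz") as mytar:
    mytar.extractall(path.join(model_path, "model"))

print(path.exists(path.join(model_path, "model/vocab_token2id.bin")))  # True
```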
Switching arguments between local environment and SageMaker using environment variables
With this setup, you can now load the model.tar.gz file from S3. However, there may be cases where you want to read the files from a different location when running fine-tuning locally.
To handle such situations, you can use os.getenv('SM_MODEL_DIR') to obtain the SageMaker environment variable SM_MODEL_DIR (the directory that will be uploaded to S3 upon container termination) and switch the source of the files between the local and SageMaker (job) environments.
import os

key = os.getenv('SM_MODEL_DIR')
if key:
    # Running as a SageMaker job: extract model.tar.gz first.
    with tarfile.open(name=path.join(args.model_path, "model.tar.gz"), mode="r:gz") as mytar:
        mytar.extractall(path.join(args.model_path, "model"))
    token2id_file = path.join(args.model_path, "model/vocab_token2id.bin")
    vocab_file = path.join(args.model_path, "model/vocab.nb")
    pretrained_model = path.join(args.model_path, "model/checkpoint-500/pytorch_model.bin")
    pretrained_config = path.join(args.model_path, "model/checkpoint-500/config.json")
else:
    # Running locally: files sit directly under model_path.
    vocab_file = path.join(args.model_path, "vocab.nb")
    token2id_file = path.join(args.model_path, "vocab_token2id.bin")
    pretrained_model = path.join(args.model_path, "checkpoint-500/pytorch_model.bin")
    pretrained_config = path.join(args.model_path, "checkpoint-500/config.json")
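The branch logic can also be factored into a small helper, which makes both environments easy to test. This is a hypothetical refactor (the helper name and structure are mine, not from the original script); it only illustrates how the SM_MODEL_DIR check selects the path prefix:

```python
import os
from os import path

def resolve_vocab_paths(model_path):
    # Inside a SageMaker job SM_MODEL_DIR is set, so the files live under
    # the extracted "model/" subdirectory; locally they sit directly
    # under model_path.
    prefix = "model/" if os.getenv("SM_MODEL_DIR") else ""
    return (
        path.join(model_path, prefix + "vocab.nb"),
        path.join(model_path, prefix + "vocab_token2id.bin"),
    )

# Simulate the SageMaker job environment:
os.environ["SM_MODEL_DIR"] = "/opt/ml/model"
print(resolve_vocab_paths("/opt/ml/input/data/input_model/")[0])
# /opt/ml/input/data/input_model/model/vocab.nb

# Simulate the local environment:
del os.environ["SM_MODEL_DIR"]
print(resolve_vocab_paths("./model/")[0])  # ./model/vocab.nb
```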