I am trying to implement a classification head for the Reformer transformer. The classification head works fine, but when I change one of the config parameters (config.axial_pos_shape, i.e. the sequence-length parameter of the model), loading the checkpoint throws the following error:
size mismatch for reformer.embeddings.position_embeddings.weights.0: copying a param with shape torch.Size([512, 1, 64]) from checkpoint, the shape in current model is torch.Size([64, 1, 64]). size mismatch for reformer.embeddings.position_embeddings.weights.1: copying a param with shape torch.Size([1, 1024, 192]) from checkpoint, the shape in current model is torch.Size([1, 128, 192]).
The config:
{
"architectures": [
"ReformerForSequenceClassification"
],
"attention_head_size": 64,
"attention_probs_dropout_prob": 0.1,
"attn_layers": [
"local",
"lsh",
"local",
"lsh",
"local",
"lsh"
],
"axial_norm_std": 1.0,
"axial_pos_embds": true,
"axial_pos_embds_dim": [
64,
192
],
"axial_pos_shape": [
64,
256
],
"chunk_size_feed_forward": 0,
"chunk_size_lm_head": 0,
"eos_token_id": 2,
"feed_forward_size": 512,
"hash_seed": null,
"hidden_act": "relu",
"hidden_dropout_prob": 0.05,
"hidden_size": 256,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": true,
"layer_norm_eps": 1e-12,
"local_attention_probs_dropout_prob": 0.05,
"local_attn_chunk_length": 64,
"local_num_chunks_after": 0,
"local_num_chunks_before": 1,
"lsh_attention_probs_dropout_prob": 0.0,
"lsh_attn_chunk_length": 64,
"lsh_num_chunks_after": 0,
"lsh_num_chunks_before": 1,
"max_position_embeddings": 8192,
"model_type": "reformer",
"num_attention_heads": 2,
"num_buckets": [
64,
128
],
"num_chunks_after": 0,
"num_chunks_before": 1,
"num_hashes": 1,
"num_hidden_layers": 6,
"output_past": true,
"pad_token_id": 0,
"task_specific_params": {
"text-generation": {
"do_sample": true,
"max_length": 100
}
},
"vocab_size": 320
}
Python Code:
# Reproduction script: build a Reformer sequence-classification model from a
# hand-modified config, then load previously saved weights into it.

# Start from a default-constructed ReformerConfig; the saved ./cnp/config.json
# is NOT used here because the from_pretrained call below is commented out.
config = ReformerConfig()
config.max_position_embeddings = 8192
# The changed parameter. Note 64 * 128 = 8192, the max_position_embeddings
# value set above.
config.axial_pos_shape=[64, 128]
# NOTE(review): the keyword below is probably meant to be `output_attentions`
# (plural) -- confirm against the ReformerConfig documentation before re-enabling.
#config = ReformerConfig.from_pretrained('./cnp/config.json', output_attention=True)
# Build the model from the modified config; no saved weights are loaded yet.
model = ReformerForSequenceClassification(config)
# This is the call that reports the size mismatch quoted in the error message:
# the checkpoint's position-embedding weights have shapes [512, 1, 64] and
# [1, 1024, 192], while the model built from axial_pos_shape=[64, 128] expects
# [64, 1, 64] and [1, 128, 192].
# NOTE(review): the checkpoint shapes imply it was saved with an axial position
# shape of (512, 1024), which differs from the [64, 256] shown in the pasted
# config -- verify which config actually produced pytorch_model.bin.
model.load_state_dict(torch.load("./cnp/pytorch_model.bin"))