PyTorch TransformerEncoder Reconstruction Error Anomaly Detection for Ordered Data

A fairly well-known anomaly detection technique uses a neural encoder-decoder (aka autoencoder) combined with reconstruction error. A few weeks ago, I experimented by inserting a TransformerEncoder module into such a system, and the results seem promising.

However, the transformer architecture is designed for input sequences that have an inherent ordering, typically the words of a sentence. So I created synthetic patient data that has an inherent order. The data looks like:

0.1668, 0.2881, 0.1000, 0.4209, 0.2587, 0.6369, 0.5745, 0.6382, 0.4587, 0.3155, 0.1677, 0.3741
0.0818, 0.3512, 0.1110, 0.5682, 0.3669, 0.8235, 0.5562, 0.5792, 0.6203, 0.4873, 0.1254, 0.3769
0.3506, 0.3578, 0.1340, 0.3156, 0.2679, 0.9513, 0.5393, 0.6684, 0.6832, 0.3133, 0.2768, 0.2262
. . .

Each line of the 200-item dataset represents a patient. The 12 values on each line are a hypothetical reading taken every hour for 12 hours (or every 2 hours for 24 hours, etc.). The idea of using synthetic medical data came from my colleague Paige R.
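A quick way to check the file format is to parse the three sample rows above with the same np.loadtxt() arguments that the PatientDataset class in the demo uses. This is a minimal sketch; the io.StringIO wrapper just stands in for the real medical_data_200.txt file.

```python
import io
import numpy as np

# The first three rows shown above, in the same comma-delimited format
# as medical_data_200.txt (a line starting with '#' is a comment).
text = """# sample rows
0.1668, 0.2881, 0.1000, 0.4209, 0.2587, 0.6369, 0.5745, 0.6382, 0.4587, 0.3155, 0.1677, 0.3741
0.0818, 0.3512, 0.1110, 0.5682, 0.3669, 0.8235, 0.5562, 0.5792, 0.6203, 0.4873, 0.1254, 0.3769
0.3506, 0.3578, 0.1340, 0.3156, 0.2679, 0.9513, 0.5393, 0.6684, 0.6832, 0.3133, 0.2768, 0.2262
"""

# Same loadtxt arguments as PatientDataset.__init__()
x = np.loadtxt(io.StringIO(text), usecols=range(0, 12),
  delimiter=",", comments="#", dtype=np.float32)

print(x.shape)  # (3, 12): 3 patients, 12 ordered readings each
```

Each row is one patient, and column j is the reading at time step j, so the column order is what the transformer is expected to exploit.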

Next, I put together a PyTorch program that creates an encoder-decoder network that predicts its input. Data items that aren’t reconstructed closely are anomalies, at least according to the model.
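The core idea can be sketched independently of the network: given the original items and their reconstructions, compute a per-item error and rank the items by it. In this minimal sketch the helper name rank_by_error() is hypothetical, and the reconstructions Y are toy values; in the demo they would come from calling the trained model on X inside a T.no_grad() block.

```python
import torch as T

def rank_by_error(X, Y, k=3):
  # X: original items [n, n_features]; Y: reconstructions, same shape.
  # Per-item error = sum of squared differences / number of features.
  errs = T.mean((X - Y) ** 2, dim=1)  # [n]
  vals, idxs = T.topk(errs, k)        # k largest errors, largest first
  return idxs, vals

# Toy data: item 1 is reconstructed badly, item 2 slightly off.
X = T.tensor([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6],
              [0.7, 0.8, 0.9]])
Y = T.tensor([[0.1, 0.2, 0.3],
              [0.0, 0.0, 0.0],
              [0.7, 0.8, 0.8]])

idxs, vals = rank_by_error(X, Y, k=2)
print(idxs.tolist())  # [1, 2]
```

Computing all errors at once like this is equivalent to the item-by-item loop in the demo's analyze_error() function, but it also gives the runners-up, not just the single worst item.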

The heart of the program is:

class Transformer_Net(T.nn.Module):
  def __init__(self):
    # 12 numeric inputs: no exact word embedding equivalent
    # pseudo embed_dim = 4
    # seq_len = 12
    super(Transformer_Net, self).__init__()

    self.fc1 = T.nn.Linear(12, 12*4)  # pseudo-embedding

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=100, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=6)

    self.dec1 = T.nn.Linear(48, 18)
    self.dec2 = T.nn.Linear(18, 12)

    # use default weight initialization

  def forward(self, x):
    # x is Size([bs, 12])
    z = T.tanh(self.fc1(x))   # [bs, 48]
    z = z.reshape(-1, 12, 4)  # [bs, 12, 4] 
    z = self.pos_enc(z)       # [bs, 12, 4]
    z = self.trans_enc(z)     # [bs, 12, 4]

    z = z.reshape(-1, 48)              # [bs, 48]
    z = T.tanh(self.dec1(z))           # [bs, 18]
    z = self.dec2(z)  # no activation  # [bs, 12]
  
    return z

The architecture is fairly complicated. Briefly, the 12 numeric inputs are fed to a fully connected Linear layer that produces 48 values. These are reshaped to 3D, shape [bs, 12, 4], to accommodate the TransformerEncoder requirement, so the 48 values are treated as 12 pseudo-embedding vectors of 4 values each, one per reading. Then positional encoding is added so the transformer knows the order of the inputs. The output of the TransformerEncoder is reshaped back to 2D and then fed to two Linear fully connected layers, designed so that the final output shape matches the input shape. Whew!
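The positional encoding step is easy to get wrong when batch_first=True. The PositionalEncoding class in the demo is adapted from the PyTorch documentation example, which assumes sequence-first tensors; with batch-first data of shape [bs, 12, 4], the encoding table must be indexed along dimension 1 (the sequence) and be identical for every item in the batch. This is a minimal standalone sketch of the standard sinusoidal scheme under that assumption; sinusoid_table() is a hypothetical helper name.

```python
import math
import torch as T

def sinusoid_table(seq_len, d_model):
  # Standard sinusoidal encoding: even columns use sin, odd columns cos.
  pos = T.arange(seq_len, dtype=T.float32).unsqueeze(1)  # [seq_len, 1]
  div = T.exp(T.arange(0, d_model, 2, dtype=T.float32) *
    (-math.log(10_000.0) / d_model))                      # [d_model/2]
  pe = T.zeros(seq_len, d_model)
  pe[:, 0::2] = T.sin(pos * div)
  pe[:, 1::2] = T.cos(pos * div)
  return pe                                               # [seq_len, d_model]

bs, seq_len, d_model = 10, 12, 4
z = T.zeros(bs, seq_len, d_model)  # stand-in for the reshaped fc1 output
pe = sinusoid_table(seq_len, d_model)
z = z + pe.unsqueeze(0)            # [1, 12, 4] broadcasts over the batch
print(z.shape)                     # torch.Size([10, 12, 4])
```

Every batch item receives the same 12 encoding rows, and each of the 12 positions receives a different row, which is what lets the encoder distinguish reading 1 from reading 12.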

One architecture alternative I want to explore concerns the numeric embedding, where each input reading maps to four values. My implementation isn’t really an embedding because it uses a Linear layer, which is fully connected, so each group of four values depends on all 12 readings rather than on just one. I want to try a true embedding layer. See https://jamesmccaffrey.wordpress.com/2023/04/20/anomaly-detection-for-tabular-data-using-a-pytorch-transformer-with-numeric-embedding/.
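One simple way to get a per-reading embedding, where each group of four values depends only on its own reading, is to apply the same small Linear(1, 4) layer to every scalar. This is a hypothetical sketch (the NumericEmbed class is not part of the demo, and not necessarily the approach in the linked post); it could stand in for the fc1 layer plus reshape, leaving all position information to the positional encoding.

```python
import torch as T

class NumericEmbed(T.nn.Module):
  # Hypothetical replacement for fc1: maps each scalar reading to a
  # 4-value vector using one shared Linear(1, 4) for all time steps.
  def __init__(self, embed_dim=4):
    super().__init__()
    self.lin = T.nn.Linear(1, embed_dim)

  def forward(self, x):
    # x: [bs, 12] -> [bs, 12, 1] -> [bs, 12, embed_dim]
    return T.tanh(self.lin(x.unsqueeze(-1)))

T.manual_seed(0)
emb = NumericEmbed()
x = T.rand(10, 12)
z = emb(x)
print(z.shape)  # torch.Size([10, 12, 4])

# Changing reading 0 leaves the vectors for readings 1..11 unchanged.
x2 = x.clone()
x2[:, 0] = 0.9
z2 = emb(x2)
print(T.allclose(z[:, 1:], z2[:, 1:]))  # True
```

The trade-off is capacity: the shared layer has only 8 parameters, versus 12 * 48 + 48 = 624 for fc1, so it may need a larger embed_dim to work well.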

After the model has been trained, I invoke an analyze_error() function that feeds each of the 200 data items to the model, fetches the output, and measures the difference between input and output. I used a custom error function: the sum of squared differences divided by the number of features (12). This is close to, but not quite, Euclidean distance, which is the square root of the sum of squared differences.
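In symbols, for an item x with n = 12 features and its reconstruction y, the error computed in analyze_error() and the Euclidean distance are related as follows:

```latex
\mathrm{err}(x, y) = \frac{1}{n} \sum_{j=1}^{n} (x_j - y_j)^2,
\qquad
d(x, y) = \sqrt{\sum_{j=1}^{n} (x_j - y_j)^2},
\qquad
\mathrm{err}(x, y) = \frac{d(x, y)^2}{n}.
```

Because err is just d squared divided by a constant, ranking items by either measure produces the same order, so the choice affects only the scale of the reported number.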

The result looks like:

Analyzing data for largest reconstruction error

Largest reconstruction idx: [140]

Largest reconstruction item:
[ 0.0362  0.0516  0.1421  0.3691  0.2506  0.9113
  0.5158  0.5966  0.6516  0.4894  0.2422  0.4905]

Largest reconstruction error: 0.0248

Its reconstruction =
[ 0.1870  0.2014  0.3200  0.5255  0.4023  0.7735
  0.6971  0.7262  0.4979  0.2906  0.2078  0.2887]
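As a sanity check, the reported error can be recomputed from the printed item and its reconstruction. The printed values are rounded to four decimals, so agreement is only to about that precision.

```python
import numpy as np

# Item [140] and its reconstruction, copied from the output above
x = np.array([0.0362, 0.0516, 0.1421, 0.3691, 0.2506, 0.9113,
              0.5158, 0.5966, 0.6516, 0.4894, 0.2422, 0.4905])
y = np.array([0.1870, 0.2014, 0.3200, 0.5255, 0.4023, 0.7735,
              0.6971, 0.7262, 0.4979, 0.2906, 0.2078, 0.2887])

sse = np.sum((x - y) ** 2)   # sum of squared differences
err = sse / len(x)           # normalized, as in analyze_error()
dist = np.sqrt(sse)          # Euclidean distance, for comparison
print("err = %0.4f" % err)   # err = 0.0248
print("dist = %0.4f" % dist)
```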

This technique seems very promising, but there are a lot of questions that need to be explored.



Putting together a PyTorch program is like putting together a jigsaw puzzle — it’s difficult to make all the pieces fit together. Jigsaw puzzle manufacturers use the same cutting template for different puzzle images. This means you can combine jigsaw puzzles if you have a lot of patience.


Demo program.

# medical_trans_anomaly.py
# Transformer based reconstruction error anomaly detection
# PyTorch 2.2.1-CPU Anaconda3-2023.09-0  Python 3.11.5
# Windows 10/11

import numpy as np
import torch as T

device = T.device('cpu') 
T.set_num_threads(1)

# -----------------------------------------------------------

class PatientDataset(T.utils.data.Dataset):
  # 12 columns
  def __init__(self, src_file):
    tmp_x = np.loadtxt(src_file, usecols=range(0,12),
      delimiter=",", comments="#", dtype=np.float32)
    self.x_data = T.tensor(tmp_x, dtype=T.float32).to(device)

  def __len__(self):
    return len(self.x_data)

  def __getitem__(self, idx):
    preds = self.x_data[idx, :]  # row idx, all cols
    sample = { 'predictors' : preds }  # as Dictionary
    return sample  

# -----------------------------------------------------------

class PositionalEncoding(T.nn.Module):  # documentation code
  # adapted for batch_first inputs of shape [bs, seq_len, d_model]
  def __init__(self, d_model: int, dropout: float=0.1,
   max_len: int=5000):
    super(PositionalEncoding, self).__init__()  # old syntax
    self.dropout = T.nn.Dropout(p=dropout)
    pe = T.zeros(max_len, d_model)  # [max_len, d_model]
    position = \
      T.arange(0, max_len, dtype=T.float).unsqueeze(1)
    div_term = T.exp(T.arange(0, d_model, 2).float() * \
      (-np.log(10_000.0) / d_model))
    pe[:, 0::2] = T.sin(position * div_term)
    pe[:, 1::2] = T.cos(position * div_term)
    pe = pe.unsqueeze(0)  # [1, max_len, d_model] for batch_first
    self.register_buffer('pe', pe)  # allows state-save

  def forward(self, x):
    # x is [bs, seq_len, d_model]: add one encoding row per position
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

# -----------------------------------------------------------

class Transformer_Net(T.nn.Module):
  def __init__(self):
    # 12 numeric inputs: no exact word embedding equivalent
    # pseudo embed_dim = 4
    # seq_len = 12
    super(Transformer_Net, self).__init__()

    self.fc1 = T.nn.Linear(12, 12*4)  # pseudo-embedding

    self.pos_enc = \
      PositionalEncoding(4, dropout=0.00)  # positional

    self.enc_layer = T.nn.TransformerEncoderLayer(d_model=4,
      nhead=2, dim_feedforward=100, 
      batch_first=True)  # d_model divisible by nhead

    self.trans_enc = T.nn.TransformerEncoder(self.enc_layer,
      num_layers=6)

    self.dec1 = T.nn.Linear(48, 18)
    self.dec2 = T.nn.Linear(18, 12)

    # use default weight initialization

  def forward(self, x):
    # x is Size([bs, 12])
    z = T.tanh(self.fc1(x))   # [bs, 48]
    z = z.reshape(-1, 12, 4)  # [bs, 12, 4] 
    z = self.pos_enc(z)       # [bs, 12, 4]
    z = self.trans_enc(z)     # [bs, 12, 4]

    z = z.reshape(-1, 48)              # [bs, 48]
    z = T.tanh(self.dec1(z))           # [bs, 18]
    z = self.dec2(z)  # no activation  # [bs, 12]
  
    return z

# -----------------------------------------------------------

def analyze_error(model, ds):
  largest_err = 0.0
  worst_x = None
  worst_y = None
  worst_idx = 0
  n_features = len(ds[0]['predictors'])

  for i in range(len(ds)):
    X = ds[i]['predictors']
    with T.no_grad():
      Y = model(X)  # should be same as X
    err = T.sum((X-Y)*(X-Y)).item()  # SSE all features
    err = err / n_features           # sort of norm'ed SSE 

    if err > largest_err:  # new largest error so far
      largest_err = err
      worst_x = X
      worst_y = Y
      worst_idx = i

  np.set_printoptions(formatter={'float': '{: 0.4f}'.format})
  print("\nLargest reconstruction idx: " + str(worst_idx))
  print("\nLargest reconstruction item: ")
  print(worst_x.numpy())
  print("\nLargest reconstruction error: %0.4f" % largest_err)
  print("\nIts reconstruction = " )
  print(worst_y.numpy())

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin patient transformer-based anomaly detect ")
  T.manual_seed(0)
  np.random.seed(0)
  
  # 1. create DataLoader objects
  print("\nCreating Patient Dataset ")

  data_file = ".\\Data\\medical_data_200.txt"
  data_ds = PatientDataset(data_file)  # 200 rows

  bat_size = 10
  data_ldr = T.utils.data.DataLoader(data_ds,
    batch_size=bat_size, shuffle=True)

  # 2. create network
  print("\nCreating Transformer encoder-decoder network ")
  net = Transformer_Net().to(device)

# -----------------------------------------------------------

  # 3. train autoencoder model
  max_epochs = 100
  ep_log_interval = 10
  # lrn_rate = 0.005
  lrn_rate = 0.010

  loss_func = T.nn.MSELoss()
  optimizer = T.optim.Adam(net.parameters(), lr=lrn_rate)

  print("\nbat_size = %3d " % bat_size)
  print("loss = " + str(loss_func))
  print("optimizer = Adam")
  print("lrn_rate = %0.3f " % lrn_rate)
  print("max_epochs = %3d " % max_epochs)
  
  print("\nStarting training")
  net.train()
  for epoch in range(0, max_epochs):
    epoch_loss = 0  # for one full epoch

    for (batch_idx, batch) in enumerate(data_ldr):
      X = batch['predictors'] 
      Y = batch['predictors'] 

      optimizer.zero_grad()
      oupt = net(X)
      loss_val = loss_func(oupt, Y)  # a tensor
      epoch_loss += loss_val.item()  # accumulate
      loss_val.backward()
      optimizer.step()

    if epoch % ep_log_interval == 0:
      print("epoch = %4d  |  loss = %0.4f" % \
       (epoch, epoch_loss))
  print("Done ")

# -----------------------------------------------------------

  # 4. find item with largest reconstruction error
  print("\nAnalyzing data for largest reconstruction error ")
  net.eval()
  analyze_error(net, data_ds)

  print("\nEnd transformer autoencoder anomaly demo ")

if __name__ == "__main__":
  main()

Synthetic data:

# medical_data_200.txt
#
0.1668, 0.2881, 0.1000, 0.4209, 0.2587, 0.6369, 0.5745, 0.6382, 0.4587, 0.3155, 0.1677, 0.3741
0.0818, 0.3512, 0.1110, 0.5682, 0.3669, 0.8235, 0.5562, 0.5792, 0.6203, 0.4873, 0.1254, 0.3769
0.3506, 0.3578, 0.1340, 0.3156, 0.2679, 0.9513, 0.5393, 0.6684, 0.6832, 0.3133, 0.2768, 0.2262
0.2746, 0.3339, 0.1073, 0.6001, 0.5955, 0.8993, 0.6122, 0.8157, 0.3413, 0.2792, 0.3634, 0.2174
0.1151, 0.0520, 0.1077, 0.5715, 0.2847, 0.7062, 0.6966, 0.5213, 0.5296, 0.1587, 0.2357, 0.3799
0.0409, 0.1656, 0.3778, 0.4657, 0.2200, 0.8144, 0.7655, 0.7060, 0.6778, 0.3346, 0.3614, 0.1550
0.0557, 0.3230, 0.2591, 0.3661, 0.5710, 0.7391, 0.8003, 0.7904, 0.6533, 0.3495, 0.3004, 0.2396
0.1080, 0.3584, 0.2712, 0.6859, 0.4654, 0.8487, 0.5459, 0.8798, 0.4800, 0.3314, 0.1633, 0.1948
0.3614, 0.2295, 0.1011, 0.5469, 0.3307, 0.8108, 0.8544, 0.6429, 0.6634, 0.3493, 0.0063, 0.4718
0.2764, 0.3989, 0.1689, 0.3549, 0.5730, 0.8787, 0.5264, 0.8022, 0.6016, 0.4692, 0.2846, 0.1497
0.0080, 0.0105, 0.1113, 0.3985, 0.5440, 0.8155, 0.7211, 0.8368, 0.3497, 0.2117, 0.2343, 0.4878
0.2244, 0.0075, 0.4203, 0.3932, 0.5228, 0.7551, 0.8454, 0.7988, 0.5225, 0.1546, 0.0240, 0.1485
0.0178, 0.0430, 0.1903, 0.5852, 0.4239, 0.6050, 0.5288, 0.8869, 0.5272, 0.1813, 0.1009, 0.3975
0.0782, 0.2325, 0.4880, 0.6387, 0.2959, 0.7975, 0.7480, 0.8316, 0.3627, 0.1074, 0.0280, 0.2945
0.2425, 0.2275, 0.2269, 0.6954, 0.4319, 0.7521, 0.7204, 0.7981, 0.5677, 0.2060, 0.0265, 0.2480
0.2519, 0.0841, 0.4011, 0.3266, 0.3041, 0.9219, 0.5774, 0.7558, 0.5099, 0.4699, 0.1053, 0.1264
0.2940, 0.3089, 0.4631, 0.6728, 0.2056, 0.6937, 0.7467, 0.8796, 0.6801, 0.3227, 0.3662, 0.3566
0.1560, 0.1944, 0.3417, 0.5198, 0.5705, 0.9675, 0.6580, 0.8853, 0.3696, 0.1505, 0.0540, 0.3023
0.0086, 0.3792, 0.4308, 0.3060, 0.2705, 0.7328, 0.5524, 0.8238, 0.4379, 0.4760, 0.2328, 0.4515
0.3379, 0.3622, 0.2840, 0.5185, 0.5194, 0.7143, 0.6961, 0.7396, 0.3062, 0.3374, 0.1735, 0.4229
0.1261, 0.3572, 0.3311, 0.3736, 0.5152, 0.8448, 0.5216, 0.6681, 0.5716, 0.4674, 0.0002, 0.4907
0.1506, 0.3895, 0.3419, 0.6315, 0.4299, 0.8512, 0.6142, 0.7347, 0.6000, 0.4433, 0.3020, 0.3792
0.3458, 0.1291, 0.3683, 0.4803, 0.3528, 0.7643, 0.6606, 0.6270, 0.5488, 0.2721, 0.3895, 0.3711
0.0794, 0.1707, 0.2373, 0.6191, 0.5520, 0.9615, 0.7651, 0.6081, 0.4009, 0.4420, 0.2111, 0.4209
0.2290, 0.2933, 0.3076, 0.6084, 0.4275, 0.7863, 0.6371, 0.5273, 0.4512, 0.1319, 0.3931, 0.1726
0.3247, 0.3500, 0.3754, 0.5278, 0.2644, 0.7868, 0.6381, 0.5900, 0.5370, 0.2249, 0.3665, 0.4639
0.1028, 0.0444, 0.1772, 0.4998, 0.4914, 0.6833, 0.5992, 0.8407, 0.4663, 0.3467, 0.0935, 0.1408
0.2063, 0.1909, 0.1611, 0.5487, 0.4176, 0.8617, 0.5578, 0.8006, 0.3888, 0.3077, 0.3141, 0.1089
0.1297, 0.3492, 0.4379, 0.5154, 0.5466, 0.9799, 0.8306, 0.8416, 0.3395, 0.3605, 0.2814, 0.3441
0.3198, 0.0138, 0.4081, 0.5927, 0.3039, 0.7028, 0.7529, 0.6381, 0.6186, 0.2785, 0.3131, 0.4962
0.1201, 0.0572, 0.4605, 0.5166, 0.5899, 0.8546, 0.8976, 0.7184, 0.5106, 0.1542, 0.1423, 0.1105
0.0642, 0.2983, 0.1122, 0.4466, 0.5449, 0.8771, 0.7764, 0.5755, 0.4768, 0.3326, 0.3959, 0.1816
0.0991, 0.1049, 0.4001, 0.4828, 0.2228, 0.8034, 0.5848, 0.8194, 0.4189, 0.1110, 0.2374, 0.4375
0.1524, 0.2999, 0.3045, 0.5164, 0.5838, 0.9216, 0.5129, 0.7838, 0.4860, 0.4790, 0.0886, 0.2068
0.0326, 0.1714, 0.1436, 0.5535, 0.5212, 0.8787, 0.8065, 0.6370, 0.6383, 0.2715, 0.3296, 0.3506
0.0574, 0.0314, 0.1073, 0.3267, 0.3834, 0.6453, 0.5111, 0.8019, 0.4579, 0.3988, 0.1810, 0.2800
0.1912, 0.1896, 0.4213, 0.4610, 0.5619, 0.6148, 0.8095, 0.5503, 0.5474, 0.1041, 0.2155, 0.1012
0.3805, 0.3622, 0.4184, 0.6661, 0.2582, 0.6631, 0.5751, 0.7490, 0.6623, 0.4960, 0.2844, 0.3927
0.3637, 0.1603, 0.1999, 0.3694, 0.2478, 0.9250, 0.5587, 0.6057, 0.6276, 0.2242, 0.3930, 0.2067
0.2135, 0.1258, 0.4643, 0.4466, 0.3734, 0.8049, 0.8756, 0.5124, 0.5868, 0.4564, 0.0109, 0.3088
0.1304, 0.3438, 0.3234, 0.5761, 0.3811, 0.8513, 0.6160, 0.5037, 0.5307, 0.2246, 0.2069, 0.4666
0.1706, 0.0990, 0.2485, 0.6727, 0.5747, 0.9377, 0.8681, 0.5912, 0.3350, 0.1909, 0.1258, 0.1699
0.2428, 0.1654, 0.4265, 0.3741, 0.4808, 0.6961, 0.7297, 0.6396, 0.3228, 0.1915, 0.2656, 0.2989
0.2076, 0.0699, 0.3283, 0.6987, 0.5267, 0.8377, 0.8904, 0.8606, 0.5382, 0.1130, 0.0374, 0.1261
0.1807, 0.1502, 0.4901, 0.3672, 0.5891, 0.9070, 0.8297, 0.7530, 0.5675, 0.2908, 0.0053, 0.2412
0.1968, 0.2920, 0.2875, 0.4830, 0.2551, 0.6044, 0.8033, 0.6280, 0.6938, 0.1881, 0.1355, 0.3096
0.3020, 0.1855, 0.1499, 0.4250, 0.4018, 0.8695, 0.8081, 0.5521, 0.3092, 0.3076, 0.3240, 0.1050
0.2690, 0.2747, 0.2797, 0.6659, 0.4577, 0.6021, 0.6938, 0.8437, 0.6322, 0.3597, 0.2695, 0.3314
0.1096, 0.2242, 0.3687, 0.4410, 0.5423, 0.6780, 0.7989, 0.6158, 0.6095, 0.2711, 0.3231, 0.2414
0.0855, 0.3069, 0.2235, 0.5933, 0.4978, 0.6886, 0.5856, 0.5796, 0.3570, 0.2508, 0.0107, 0.1444
0.2698, 0.3199, 0.1322, 0.3927, 0.2831, 0.9669, 0.7845, 0.7216, 0.4218, 0.4339, 0.1741, 0.4694
0.2824, 0.1912, 0.1505, 0.6904, 0.2639, 0.6810, 0.6725, 0.6617, 0.3587, 0.3917, 0.0755, 0.3576
0.3017, 0.0843, 0.3404, 0.5996, 0.4553, 0.8389, 0.6182, 0.7926, 0.6781, 0.2702, 0.3129, 0.1225
0.3341, 0.0769, 0.2580, 0.4200, 0.2320, 0.9619, 0.6481, 0.7123, 0.4976, 0.1529, 0.0826, 0.1305
0.2032, 0.1046, 0.2428, 0.3432, 0.5150, 0.6426, 0.8943, 0.5709, 0.5290, 0.1179, 0.3148, 0.1758
0.2112, 0.2960, 0.1600, 0.5204, 0.2866, 0.9037, 0.7892, 0.5706, 0.6448, 0.1079, 0.3441, 0.3236
0.1613, 0.3035, 0.3868, 0.6949, 0.3112, 0.6015, 0.8736, 0.8432, 0.5915, 0.3067, 0.2828, 0.4122
0.1500, 0.3081, 0.4002, 0.5453, 0.3607, 0.8789, 0.5012, 0.8100, 0.6586, 0.1957, 0.0483, 0.1881
0.1208, 0.3532, 0.3173, 0.4147, 0.2553, 0.7161, 0.7455, 0.6297, 0.4829, 0.2776, 0.3313, 0.2705
0.1383, 0.2700, 0.1886, 0.4869, 0.3259, 0.8507, 0.8509, 0.6791, 0.6138, 0.2828, 0.2625, 0.1527
0.1732, 0.3637, 0.3422, 0.6067, 0.4019, 0.7992, 0.8372, 0.5271, 0.5293, 0.4771, 0.2071, 0.1778
0.3392, 0.1007, 0.3803, 0.5161, 0.5795, 0.8497, 0.8352, 0.5032, 0.6957, 0.1311, 0.1289, 0.4785
0.0036, 0.3291, 0.4445, 0.4759, 0.3023, 0.9211, 0.6911, 0.5537, 0.6711, 0.4584, 0.1966, 0.4427
0.1674, 0.2734, 0.2592, 0.5023, 0.2758, 0.9860, 0.6177, 0.5414, 0.3577, 0.1056, 0.2864, 0.3258
0.3178, 0.2028, 0.4167, 0.5783, 0.5111, 0.7626, 0.7591, 0.5719, 0.4287, 0.1690, 0.1635, 0.1966
0.1628, 0.3901, 0.2281, 0.6930, 0.4545, 0.7500, 0.8430, 0.7478, 0.4008, 0.4171, 0.1732, 0.2430
0.1321, 0.2789, 0.2075, 0.6233, 0.3181, 0.8176, 0.6952, 0.8421, 0.6554, 0.1738, 0.2341, 0.4593
0.1784, 0.3687, 0.2116, 0.5435, 0.4730, 0.6913, 0.5055, 0.6667, 0.6754, 0.2372, 0.3119, 0.1699
0.1368, 0.0578, 0.3867, 0.5797, 0.4754, 0.7014, 0.7769, 0.5909, 0.4699, 0.2488, 0.1421, 0.1231
0.2527, 0.2829, 0.3454, 0.5593, 0.2680, 0.6598, 0.7057, 0.8501, 0.3736, 0.2851, 0.1716, 0.2989
0.0646, 0.1370, 0.2048, 0.6378, 0.5201, 0.7707, 0.7428, 0.5582, 0.5038, 0.2188, 0.3439, 0.3686
0.2534, 0.0499, 0.2882, 0.6946, 0.5793, 0.8580, 0.5607, 0.7557, 0.5263, 0.2875, 0.1712, 0.3397
0.3400, 0.3004, 0.3317, 0.6699, 0.2259, 0.9965, 0.5212, 0.5798, 0.4691, 0.1430, 0.2495, 0.1192
0.1138, 0.0244, 0.3814, 0.5674, 0.3514, 0.6753, 0.7988, 0.6362, 0.6181, 0.2952, 0.2103, 0.1114
0.2577, 0.1403, 0.1917, 0.4736, 0.3530, 0.7879, 0.8918, 0.6458, 0.6098, 0.3211, 0.3557, 0.2420
0.0982, 0.3644, 0.1174, 0.6803, 0.4226, 0.7505, 0.8980, 0.5233, 0.5067, 0.1124, 0.2285, 0.1722
0.2524, 0.3924, 0.4500, 0.4807, 0.4834, 0.9110, 0.6979, 0.7114, 0.3603, 0.2478, 0.0569, 0.3908
0.1908, 0.1796, 0.4544, 0.5110, 0.3636, 0.7076, 0.5288, 0.6673, 0.3103, 0.2165, 0.2014, 0.4864
0.0438, 0.2692, 0.3000, 0.6108, 0.2574, 0.6333, 0.6597, 0.8188, 0.3767, 0.4071, 0.1161, 0.1868
0.0067, 0.1595, 0.2524, 0.5637, 0.2284, 0.6610, 0.5066, 0.5455, 0.5607, 0.2611, 0.1284, 0.3232
0.3974, 0.3338, 0.3798, 0.6673, 0.2159, 0.6281, 0.6896, 0.6397, 0.6749, 0.2958, 0.2159, 0.4581
0.1787, 0.3508, 0.2014, 0.4095, 0.3313, 0.8190, 0.5881, 0.7686, 0.3571, 0.1376, 0.3481, 0.1947
0.1544, 0.2286, 0.3103, 0.3304, 0.5497, 0.9805, 0.8250, 0.6135, 0.5111, 0.2358, 0.2219, 0.4898
0.1247, 0.2675, 0.2304, 0.6098, 0.3303, 0.9559, 0.8007, 0.8051, 0.4878, 0.1843, 0.0166, 0.2287
0.0148, 0.2775, 0.3681, 0.4722, 0.5071, 0.8144, 0.5159, 0.5539, 0.3774, 0.2343, 0.0209, 0.3420
0.2048, 0.2470, 0.2729, 0.6391, 0.3816, 0.6062, 0.8492, 0.7625, 0.6292, 0.4807, 0.0204, 0.1940
0.0253, 0.1687, 0.4455, 0.3326, 0.3892, 0.6502, 0.8092, 0.8366, 0.3173, 0.2946, 0.0958, 0.4810
0.3776, 0.2456, 0.4894, 0.4379, 0.5591, 0.7738, 0.5943, 0.8763, 0.5737, 0.1260, 0.3482, 0.3806
0.2420, 0.2929, 0.2014, 0.5402, 0.5258, 0.6216, 0.5522, 0.8370, 0.5473, 0.3125, 0.0993, 0.2180
0.3491, 0.1687, 0.1258, 0.6588, 0.2814, 0.9305, 0.8527, 0.6947, 0.5394, 0.3109, 0.2499, 0.4420
0.1129, 0.3535, 0.3271, 0.3460, 0.2908, 0.8384, 0.5958, 0.5526, 0.3647, 0.4379, 0.2409, 0.4854
0.1383, 0.2383, 0.3396, 0.5463, 0.2237, 0.9001, 0.8793, 0.7139, 0.3770, 0.4012, 0.0029, 0.2313
0.3670, 0.2353, 0.4421, 0.5419, 0.5290, 0.9518, 0.6284, 0.5492, 0.5885, 0.2761, 0.0507, 0.3359
0.0144, 0.0801, 0.4153, 0.3048, 0.3213, 0.6086, 0.8990, 0.7328, 0.4174, 0.4716, 0.2028, 0.2819
0.2351, 0.1057, 0.2221, 0.4487, 0.2978, 0.8338, 0.7783, 0.5288, 0.6884, 0.4012, 0.3225, 0.4007
0.0320, 0.1927, 0.2783, 0.5690, 0.3795, 0.8817, 0.7727, 0.7789, 0.5474, 0.1604, 0.3043, 0.4124
0.3616, 0.0935, 0.1707, 0.4564, 0.3282, 0.9262, 0.7454, 0.8040, 0.4711, 0.1398, 0.0460, 0.2494
0.0775, 0.3283, 0.3399, 0.5755, 0.3964, 0.6353, 0.5940, 0.6846, 0.3794, 0.1102, 0.2918, 0.3900
0.1322, 0.3374, 0.2714, 0.6459, 0.4628, 0.8324, 0.5803, 0.7118, 0.6578, 0.2220, 0.3484, 0.4635
0.1319, 0.2732, 0.4597, 0.3303, 0.5514, 0.6763, 0.8399, 0.7668, 0.4377, 0.1606, 0.2541, 0.4391
0.3287, 0.2513, 0.4825, 0.5360, 0.2791, 0.7717, 0.6347, 0.8968, 0.4521, 0.4971, 0.2075, 0.1689
0.0298, 0.1481, 0.1494, 0.5538, 0.3656, 0.9964, 0.8720, 0.5597, 0.4580, 0.2846, 0.2244, 0.4121
0.1949, 0.1680, 0.2048, 0.6643, 0.2089, 0.9284, 0.5754, 0.7743, 0.4421, 0.4897, 0.0491, 0.1750
0.3558, 0.2334, 0.2237, 0.3003, 0.2910, 0.6582, 0.5814, 0.8585, 0.6492, 0.3801, 0.1882, 0.4309
0.1983, 0.1454, 0.2102, 0.6699, 0.3548, 0.7972, 0.6018, 0.8540, 0.4533, 0.2190, 0.2870, 0.1763
0.0473, 0.3348, 0.3977, 0.5362, 0.2972, 0.8493, 0.7553, 0.6310, 0.3270, 0.4522, 0.1840, 0.4055
0.1016, 0.2366, 0.2715, 0.4528, 0.2507, 0.6977, 0.5317, 0.6211, 0.5967, 0.3460, 0.2690, 0.1034
0.2714, 0.2013, 0.1924, 0.3700, 0.2740, 0.9377, 0.8930, 0.8655, 0.4389, 0.4121, 0.2186, 0.4266
0.1935, 0.2360, 0.4149, 0.3401, 0.4148, 0.7464, 0.7417, 0.8835, 0.4571, 0.2572, 0.3163, 0.3580
0.1576, 0.2756, 0.2616, 0.3544, 0.3803, 0.7338, 0.5872, 0.8703, 0.5759, 0.3395, 0.2987, 0.3168
0.2802, 0.3722, 0.4450, 0.3670, 0.3053, 0.6286, 0.8915, 0.5946, 0.5642, 0.1359, 0.0843, 0.3011
0.0420, 0.1555, 0.3152, 0.4357, 0.4224, 0.8147, 0.6562, 0.7785, 0.5714, 0.3749, 0.2246, 0.2432
0.2452, 0.3743, 0.3388, 0.6918, 0.3764, 0.8958, 0.5150, 0.8059, 0.5073, 0.1021, 0.1109, 0.3139
0.3072, 0.0212, 0.3196, 0.6204, 0.4598, 0.9726, 0.5299, 0.6107, 0.6677, 0.4060, 0.2399, 0.4332
0.3584, 0.3891, 0.4994, 0.3559, 0.2282, 0.6294, 0.5059, 0.8887, 0.3379, 0.4367, 0.2741, 0.2950
0.1387, 0.1415, 0.2015, 0.6644, 0.4903, 0.6104, 0.6846, 0.6125, 0.3116, 0.4539, 0.3084, 0.2319
0.3186, 0.1299, 0.2232, 0.6712, 0.5908, 0.8094, 0.8808, 0.8552, 0.5072, 0.2491, 0.2841, 0.2823
0.2421, 0.3962, 0.4096, 0.4337, 0.2356, 0.6740, 0.7107, 0.6668, 0.6203, 0.4733, 0.0711, 0.4440
0.3830, 0.3838, 0.1151, 0.3230, 0.2023, 0.7251, 0.5223, 0.6175, 0.4424, 0.4815, 0.1908, 0.2113
0.2012, 0.2573, 0.1445, 0.6032, 0.5408, 0.9377, 0.8562, 0.6527, 0.4775, 0.1406, 0.0903, 0.4885
0.1138, 0.3630, 0.4562, 0.6690, 0.2625, 0.9519, 0.7528, 0.5843, 0.4380, 0.1640, 0.1777, 0.1335
0.0661, 0.0789, 0.3764, 0.5295, 0.5493, 0.6979, 0.7518, 0.5138, 0.5065, 0.4453, 0.3937, 0.1039
0.1011, 0.0895, 0.1190, 0.3041, 0.3961, 0.6182, 0.6113, 0.7560, 0.4179, 0.2079, 0.2362, 0.2522
0.2810, 0.1984, 0.3533, 0.4414, 0.3148, 0.6533, 0.8748, 0.8219, 0.6752, 0.1742, 0.3733, 0.4743
0.1134, 0.1099, 0.3206, 0.3741, 0.3768, 0.6741, 0.8790, 0.6975, 0.6856, 0.3580, 0.1937, 0.4871
0.0573, 0.2529, 0.3648, 0.4735, 0.2237, 0.7971, 0.6861, 0.8227, 0.4027, 0.2566, 0.0961, 0.3746
0.3960, 0.0710, 0.4286, 0.3924, 0.2232, 0.6555, 0.8741, 0.8554, 0.4157, 0.4791, 0.3400, 0.2739
0.1872, 0.2519, 0.1632, 0.3059, 0.3062, 0.6062, 0.7698, 0.7206, 0.4287, 0.4121, 0.0583, 0.1980
0.1169, 0.0784, 0.1352, 0.6480, 0.2353, 0.8735, 0.5482, 0.5043, 0.5229, 0.4628, 0.3442, 0.2354
0.0109, 0.3203, 0.4224, 0.6474, 0.4679, 0.9231, 0.8590, 0.6815, 0.5231, 0.3025, 0.2768, 0.3732
0.2081, 0.3314, 0.3023, 0.6299, 0.3127, 0.6714, 0.8879, 0.7968, 0.4039, 0.3325, 0.3821, 0.1323
0.0334, 0.2477, 0.1898, 0.6061, 0.4273, 0.8665, 0.5431, 0.5337, 0.5500, 0.2639, 0.0349, 0.2484
0.2689, 0.0758, 0.4583, 0.6799, 0.5846, 0.8920, 0.6625, 0.7975, 0.4152, 0.2258, 0.2424, 0.3379
0.3515, 0.1018, 0.4065, 0.6764, 0.2004, 0.7904, 0.7628, 0.8373, 0.3735, 0.4425, 0.1457, 0.4569
0.0111, 0.0342, 0.4926, 0.5438, 0.3671, 0.6677, 0.7600, 0.5148, 0.4246, 0.2294, 0.2430, 0.3603
0.3384, 0.3710, 0.3642, 0.5313, 0.3595, 0.9866, 0.5616, 0.8580, 0.4244, 0.3194, 0.2728, 0.1946
0.0671, 0.2034, 0.4167, 0.5770, 0.2476, 0.9603, 0.6919, 0.8787, 0.5214, 0.1339, 0.0810, 0.4417
0.2824, 0.3580, 0.2317, 0.5112, 0.4602, 0.8377, 0.5926, 0.6707, 0.3992, 0.4382, 0.3947, 0.1266
0.2856, 0.1320, 0.3502, 0.4033, 0.4604, 0.7485, 0.6179, 0.8647, 0.6770, 0.3454, 0.0858, 0.4758
0.3002, 0.3004, 0.1847, 0.6321, 0.2960, 0.8526, 0.7843, 0.7970, 0.4561, 0.4039, 0.3612, 0.3350
0.0362, 0.0516, 0.1421, 0.3691, 0.2506, 0.9113, 0.5158, 0.5966, 0.6516, 0.4894, 0.2422, 0.4905
0.0174, 0.3794, 0.2245, 0.6196, 0.5243, 0.9440, 0.6834, 0.8723, 0.4032, 0.4738, 0.2476, 0.4942
0.0131, 0.2901, 0.3262, 0.4970, 0.3037, 0.7307, 0.5998, 0.5877, 0.6199, 0.3010, 0.0333, 0.4108
0.2135, 0.2958, 0.3062, 0.4600, 0.5945, 0.6113, 0.8731, 0.8723, 0.4564, 0.1858, 0.2477, 0.1712
0.3213, 0.1016, 0.2163, 0.6664, 0.5612, 0.8142, 0.8451, 0.6410, 0.6992, 0.2735, 0.1179, 0.1166
0.3959, 0.1860, 0.3938, 0.5563, 0.4892, 0.6204, 0.8680, 0.8707, 0.5208, 0.4856, 0.1124, 0.3409
0.1586, 0.2512, 0.2203, 0.6277, 0.2279, 0.6168, 0.5198, 0.5602, 0.4581, 0.4822, 0.0443, 0.3590
0.2110, 0.1413, 0.1793, 0.3882, 0.2175, 0.8853, 0.7615, 0.6775, 0.5876, 0.1440, 0.3755, 0.4391
0.2636, 0.1515, 0.2666, 0.4929, 0.5741, 0.9454, 0.6912, 0.7218, 0.6502, 0.4797, 0.2557, 0.3994
0.1406, 0.3672, 0.4347, 0.5208, 0.5471, 0.9399, 0.8234, 0.5523, 0.5144, 0.4603, 0.3083, 0.2683
0.3912, 0.1003, 0.2334, 0.6817, 0.5235, 0.9601, 0.5046, 0.6519, 0.3942, 0.2184, 0.2952, 0.3896
0.1847, 0.1461, 0.3339, 0.5135, 0.4202, 0.8462, 0.6583, 0.8087, 0.4005, 0.3623, 0.3842, 0.1014
0.2893, 0.0436, 0.3175, 0.5508, 0.2972, 0.9655, 0.7489, 0.5927, 0.6081, 0.1422, 0.2221, 0.1380
0.2324, 0.1012, 0.3598, 0.5863, 0.4097, 0.8630, 0.8253, 0.8230, 0.5479, 0.2804, 0.1632, 0.4499
0.2812, 0.0740, 0.3263, 0.6635, 0.2603, 0.9382, 0.8281, 0.6997, 0.3150, 0.1590, 0.3691, 0.1177
0.1490, 0.2476, 0.1788, 0.6924, 0.2624, 0.8159, 0.7472, 0.5924, 0.4865, 0.2964, 0.3067, 0.3537
0.1870, 0.2668, 0.2814, 0.5103, 0.4634, 0.6725, 0.8152, 0.7411, 0.6921, 0.1890, 0.1865, 0.4583
0.2406, 0.2904, 0.2655, 0.5737, 0.5784, 0.6801, 0.8229, 0.8878, 0.3058, 0.1861, 0.0470, 0.4070
0.0325, 0.3994, 0.2558, 0.6621, 0.2754, 0.7763, 0.8379, 0.7381, 0.4849, 0.3051, 0.2296, 0.2761
0.0774, 0.0456, 0.1605, 0.3210, 0.4235, 0.9391, 0.6436, 0.5301, 0.4703, 0.3408, 0.1702, 0.4243
0.3768, 0.3480, 0.3816, 0.4881, 0.4731, 0.6054, 0.8930, 0.5688, 0.6625, 0.4228, 0.3201, 0.2076
0.0877, 0.1715, 0.4397, 0.4802, 0.5742, 0.7411, 0.7478, 0.8923, 0.5574, 0.1182, 0.1854, 0.2339
0.2129, 0.1139, 0.1489, 0.5673, 0.4725, 0.9469, 0.5530, 0.8194, 0.6307, 0.3586, 0.0820, 0.2046
0.1674, 0.2167, 0.2910, 0.4870, 0.5359, 0.9480, 0.8267, 0.8511, 0.5284, 0.4856, 0.2426, 0.3416
0.1277, 0.2725, 0.1208, 0.3538, 0.2521, 0.8621, 0.5701, 0.6365, 0.3177, 0.1935, 0.3857, 0.3037
0.0604, 0.2092, 0.4774, 0.6463, 0.3568, 0.7135, 0.8047, 0.5962, 0.4017, 0.1336, 0.3457, 0.2792
0.2247, 0.2947, 0.4186, 0.4790, 0.2737, 0.9315, 0.5124, 0.8787, 0.5308, 0.4502, 0.2434, 0.2007
0.1185, 0.2132, 0.4848, 0.3738, 0.4040, 0.7375, 0.8079, 0.8211, 0.4698, 0.1816, 0.0268, 0.1795
0.1090, 0.2395, 0.4492, 0.3513, 0.5833, 0.8733, 0.7968, 0.8932, 0.4664, 0.3126, 0.2717, 0.3052
0.1196, 0.0422, 0.2140, 0.6075, 0.4563, 0.9205, 0.7068, 0.5928, 0.5512, 0.2216, 0.0118, 0.4613
0.1681, 0.1705, 0.3965, 0.6794, 0.2290, 0.6691, 0.6289, 0.5994, 0.4982, 0.4298, 0.3416, 0.3383
0.0961, 0.0746, 0.3974, 0.3949, 0.4158, 0.8997, 0.5913, 0.5698, 0.3159, 0.1590, 0.3406, 0.1419
0.0895, 0.0131, 0.3705, 0.3927, 0.3654, 0.6557, 0.6509, 0.6698, 0.4892, 0.2691, 0.0411, 0.4090
0.0549, 0.1697, 0.2088, 0.4204, 0.4690, 0.8071, 0.5760, 0.6877, 0.4358, 0.3818, 0.0904, 0.4380
0.2105, 0.2002, 0.2015, 0.3745, 0.3887, 0.9927, 0.5532, 0.6134, 0.6203, 0.3659, 0.1115, 0.2259
0.1680, 0.2411, 0.3877, 0.6429, 0.4724, 0.6948, 0.8703, 0.8125, 0.4230, 0.2220, 0.3525, 0.1504
0.2534, 0.1111, 0.4309, 0.4458, 0.4933, 0.6770, 0.6926, 0.8214, 0.4588, 0.1646, 0.2596, 0.4013
0.1121, 0.1805, 0.4602, 0.6488, 0.4829, 0.8364, 0.8270, 0.8631, 0.3010, 0.2589, 0.2246, 0.3936
0.0350, 0.1599, 0.2118, 0.4694, 0.3992, 0.8640, 0.6985, 0.7482, 0.5330, 0.2713, 0.0020, 0.1778
0.1281, 0.0588, 0.3395, 0.5446, 0.4000, 0.7283, 0.7613, 0.5761, 0.3024, 0.3940, 0.1774, 0.3791
0.0740, 0.0250, 0.2512, 0.5784, 0.2411, 0.6783, 0.6816, 0.7485, 0.6000, 0.1439, 0.2498, 0.2549
0.2680, 0.0814, 0.3112, 0.3689, 0.2075, 0.7948, 0.5737, 0.7553, 0.5146, 0.4100, 0.1572, 0.4958
0.2061, 0.1915, 0.3998, 0.5291, 0.3450, 0.7957, 0.5757, 0.6574, 0.3120, 0.2850, 0.1098, 0.3107
0.0117, 0.2220, 0.2172, 0.5310, 0.4931, 0.7761, 0.7653, 0.5956, 0.6994, 0.1972, 0.3763, 0.1869
0.1990, 0.3285, 0.3866, 0.5822, 0.3762, 0.9017, 0.8680, 0.6765, 0.5112, 0.1264, 0.1563, 0.2869
0.3581, 0.0442, 0.3925, 0.5182, 0.4426, 0.6119, 0.5587, 0.6136, 0.3019, 0.3677, 0.3481, 0.3188
0.2173, 0.2463, 0.2209, 0.4467, 0.4300, 0.9237, 0.5806, 0.6310, 0.5972, 0.2364, 0.0190, 0.1625
0.0775, 0.1980, 0.3540, 0.6521, 0.5610, 0.7229, 0.8014, 0.6130, 0.4474, 0.2171, 0.1655, 0.1859
0.1276, 0.0488, 0.4852, 0.5016, 0.5692, 0.8985, 0.6831, 0.8018, 0.5512, 0.2215, 0.2087, 0.3849
0.1221, 0.0050, 0.2073, 0.6187, 0.5720, 0.8501, 0.5531, 0.8030, 0.5108, 0.4015, 0.3434, 0.4790
0.2618, 0.3417, 0.3970, 0.5908, 0.5435, 0.9692, 0.8608, 0.6583, 0.3336, 0.4318, 0.2156, 0.3168
0.3301, 0.0128, 0.4512, 0.3139, 0.4773, 0.8350, 0.7567, 0.6496, 0.4102, 0.3038, 0.3543, 0.3261
0.2653, 0.1766, 0.4889, 0.5970, 0.3420, 0.8614, 0.7170, 0.8536, 0.4100, 0.1432, 0.0765, 0.2548
0.3416, 0.1083, 0.3505, 0.5494, 0.3632, 0.6201, 0.7979, 0.6183, 0.4594, 0.2509, 0.2654, 0.1345
0.2200, 0.0062, 0.4932, 0.5394, 0.3536, 0.6587, 0.7788, 0.8623, 0.4272, 0.4066, 0.1150, 0.4829
0.2809, 0.2500, 0.4723, 0.4076, 0.5694, 0.8712, 0.8085, 0.7287, 0.6336, 0.3793, 0.0586, 0.3450
0.1117, 0.3664, 0.1793, 0.4143, 0.2191, 0.7790, 0.7230, 0.7294, 0.6622, 0.2390, 0.1790, 0.4405
0.2307, 0.2616, 0.2113, 0.4950, 0.4484, 0.6534, 0.5132, 0.5454, 0.4910, 0.1096, 0.2505, 0.1390
0.1004, 0.1706, 0.1463, 0.4082, 0.2084, 0.9940, 0.7446, 0.6513, 0.3106, 0.2559, 0.1810, 0.4724
0.1114, 0.2459, 0.3661, 0.3744, 0.4023, 0.9146, 0.5386, 0.7424, 0.3104, 0.1028, 0.0238, 0.2926

This entry was posted in PyTorch, Transformers.