PyTorchにおける`torch.norm`の多態性とベクトル・行列・高次元テンソルのノルム計算

PyTorchのtorch.norm関数は、テンソルのノルム(大きさ)を計算するための強力なツールです。特徴的なのは、入力データの次元に応じて動的に処理を切り替える多態性ポリモーフィズム)を持っている点です。本記事では、torch.normがどのように多態性を実現し、ベクトルや行列、高次元テンソルに対してどのように機能するのかを具体例とともに解説します。


1. torch.normとは?

torch.normは、テンソルのノルムを計算する関数です。ノルムとは、ベクトルや行列の「長さ」や「大きさ」を示す数値で、幾何学的にはユークリッド距離に対応することが多いです。

この関数はデフォルトでL2ノルム(ユークリッドノルム)やFrobeniusノルムを計算しますが、引数を指定することで他の種類のノルムも計算可能です。


2. torch.norm多態性の特徴

torch.normは、入力テンソルの次元や形状に応じて異なる計算方法を採用します。具体的には以下のような動作をします:

  1. 1次元テンソル(ベクトル):L2ノルムを計算
  2. 2次元テンソル(行列):Frobeniusノルムを計算
  3. 高次元テンソル:指定された次元(軸)に沿ったノルムを計算

次に、各ケースを例を挙げて解説します。


2.1. ベクトルのノルム計算

ベクトル(1次元テンソル)では、デフォルトでL2ノルムが計算されます。L2ノルムは各成分の二乗和の平方根です。

例:

import torch

u = torch.tensor([3.0, -4.0])  # 1次元ベクトル
result = torch.norm(u)  # デフォルトでL2ノルム
print(result)  # 結果: tensor(5.)

この場合、ベクトルの形状が自動的に判別され、適切なL2ノルムが計算されます。


2.2. 行列のノルム計算

行列(2次元テンソル)では、デフォルトでFrobeniusノルムが計算されます。これは全要素の二乗和の平方根です。

例:

A = torch.ones((4, 9))  # 2次元行列
result = torch.norm(A)  # デフォルトでFrobeniusノルム
print(result)  # 結果: tensor(6.)

ここでもテンソルの次元に応じた適切なノルムが選択されます。


2.3. 高次元テンソルのノルム計算

高次元テンソル(3次元以上)では、dim引数を用いることで特定の次元に沿ったノルム計算が可能です。

例:

B = torch.ones((2, 3, 4))  # 3次元テンソル
result = torch.norm(B, dim=(1, 2))  # 1軸と2軸に沿ったノルム計算
print(result)  # 結果: tensor([3.4641, 3.4641])

この例では、各サブテンソルのFrobeniusノルムが計算されています。


3. 多態性の仕組み

torch.norm多態性を実現する方法は、テンソルの次元数や形状を動的に判別し、適切な計算アルゴリズムを適用することにあります。

以下は擬似コードです:

def norm(tensor, p=2, dim=None):
    if tensor.ndim == 1:  # 1次元(ベクトル)
        return vector_norm(tensor, p)
    elif tensor.ndim == 2:  # 2次元(行列)
        return matrix_norm(tensor, p)
    elif dim is not None:  # 高次元テンソル
        reduced_tensor = reduce_along_dims(tensor, dim)
        return vector_norm(reduced_tensor, p)
    else:  # デフォルトでFrobeniusノルム
        return frobenius_norm(tensor)

このようにして、torch.normは異なる入力タイプに適応した柔軟な計算を行っています。


4. 結論

PyTorchのtorch.normは、テンソルの次元や形状に基づいた多態性を実現することで、幅広い用途に対応する非常に便利な関数です。この多態性により、ユーザーは次元や形状を気にせず、テンソルの大きさを簡単に計算できます。

多次元データを扱う際にも、統一されたインターフェースで操作できるのはPyTorchの強みの一つです。ぜひtorch.normを活用してみてください。


この記事が参考になれば幸いです!