Google USM: Extracted Gemma-3n Audio Encoder (USM)

このモデルの実態は不明確です。Introducing Gemma 3n: The developer guideには、 USMに基づくエンコーダーが使用されていると記述されていますが、USMの論文とこのモデルにはいくつかの異なる点が存在します。 このモデルは0.6Bですが、USMの論文の0.6Bモデルとは層の数が異なります。 このモデルは Gemma 3n の AudioEncoder であり、本来の USM とは異なる可能性があります。

Model Description

このモデルは、Googleのマルチモーダルモデル google/gemma-3n-e2b-it から、音声エンコーダー部分 (audio_tower) のみを抽出したものです。

bf16版:https://huggingface.co/Atotti/google-usm-bf16

アーキテクチャは、論文 Universal Speech Model に基づくGemma3nAudioEncoderです。

このエンコーダーは、音声波形データを受け取り、その内容を表現する高次元の特徴量(エンコーディング)のシーケンスに変換する役割を果たします。

Intended Use

このモデルは単体で音声認識(文字起こし)などを行うものではなく、より大きなモデルのコンポーネントとして使用されることを想定しています。

  • マルチモーダルモデルの音声入力部として: 生成AIに音声情報を与えるための特徴量を抽出します。
  • 音声分類: このモデルの出力に分類ヘッドを追加して、特定の音声(例:笑い声、拍手、特定の単語)を分類するタスクでファインチューニングします。
  • 音声類似度検索: 音声のエンコーディングをベクトルとして扱い、意味的に似た音声を検索します。
  • 話者認識: 音声から話者を識別するタスクのベースモデルとして利用します。

How to Use

import torch
import soundfile as sf
from transformers import Gemma3nAudioEncoder, Gemma3nAudioFeatureExtractor

encoder_id = "Atotti/google-usm"
source_model_id = "google/gemma-3n-e2b-it"

audio_encoder = Gemma3nAudioEncoder.from_pretrained(encoder_id)
feature_extractor = Gemma3nAudioFeatureExtractor.from_pretrained(source_model_id)

device = "cuda" if torch.cuda.is_available() else "cpu"
audio_encoder.to(device)
audio_encoder.eval()

waveform, sampling_rate = sf.read("/path/to/your_audio_file.wav")


inputs = feature_extractor(
    [waveform],
    sampling_rate=sampling_rate,
    return_tensors="pt"
)

audio_mel = inputs["input_features"].to(device)
audio_mel_mask = (inputs["input_features_mask"] == 0).to(device)

with torch.inference_mode():

    audio_encodings, output_mask = audio_encoder(
        audio_mel=audio_mel,
        audio_mel_mask=audio_mel_mask
    )

print(audio_encodings.shape) # torch.Size([1, 18, 1536])
print(audio_encodings[0, :5, :10])
# tensor([[ 0.0014, -0.0044,  0.0003,  0.0084, -0.0076, -0.0194,  0.0071,  0.0160,
#           0.0137,  0.0146],
#         [-0.0153,  0.0051,  0.0111, -0.0134, -0.0032, -0.0134,  0.0112, -0.0163,
#           0.0050,  0.0036],
#         [ 0.0003, -0.0022,  0.0164, -0.0090, -0.0033, -0.0043,  0.0030, -0.0042,
#          -0.0060,  0.0066],
#         [-0.0006, -0.0194, -0.0006, -0.0097, -0.0049, -0.0132,  0.0012,  0.0175,
#          -0.0242, -0.0091],
#         [ 0.0127,  0.0122,  0.0125,  0.0277,  0.0116,  0.0152,  0.0142, -0.0099,
#          -0.0080, -0.0233]], device='cuda:0')

Model Architecture

Gemma3nAudioEncoder(
  (subsample_conv_projection): Gemma3nAudioSubSampleConvProjection(
    (conv_0): Gemma3nAudioSSCPConvBlock(
      (conv): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (norm): Gemma3nAudioCumulativeGroupNorm()
      (activation): ReLU()
    )
    (conv_1): Gemma3nAudioSSCPConvBlock(
      (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (norm): Gemma3nAudioCumulativeGroupNorm()
      (activation): ReLU()
    )
    (input_proj_linear): Linear(in_features=1024, out_features=1536, bias=False)
  )
  (conformer): ModuleList(
    (0-11): 12 x Gemma3nAudioConformerBlock(
      (ffw_layer_start): Gemma3nAudioConformerFeedForward(
        (pre_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (ffw_layer_1): Linear(in_features=1536, out_features=6144, bias=False)
        (ffw_layer_2): Linear(in_features=6144, out_features=1536, bias=False)
        (post_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
      )
      (attention): Gemma3nAudioConformerAttention(
        (pre_attn_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (attn): Gemma3nAudioAttention(
          (relative_position_embedding): Gemma3nAudioRelativePositionEmbedding(
            (pos_proj): Linear(in_features=1536, out_features=1536, bias=False)
          )
          (q_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (k_proj): Linear(in_features=1536, out_features=1536, bias=False)
          (v_proj): Linear(in_features=1536, out_features=1536, bias=False)
        )
        (post): Linear(in_features=1536, out_features=1536, bias=False)
        (post_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
      )
      (lconv1d): Gemma3nAudioConformerLightConv1d(
        (pre_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (linear_start): Linear(in_features=1536, out_features=3072, bias=False)
        (depthwise_conv1d): Conv1d(1536, 1536, kernel_size=(5,), stride=(1,), groups=1536, bias=False)
        (conv_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (linear_end): Linear(in_features=1536, out_features=1536, bias=False)
      )
      (ffw_layer_end): Gemma3nAudioConformerFeedForward(
        (pre_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
        (ffw_layer_1): Linear(in_features=1536, out_features=6144, bias=False)
        (ffw_layer_2): Linear(in_features=6144, out_features=1536, bias=False)
        (post_layer_norm): Gemma3nRMSNorm((1536,), eps=1e-06)
      )
      (norm): Gemma3nRMSNorm((1536,), eps=1e-06)
    )
  )
)
Downloads last month
1,178
Safetensors
Model size
681M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support