PyTorch
ssl-aasist
custom_code
File size: 389 Bytes
fb0facd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <torch/extension.h> // @manual=//caffe2:torch_extension

/**
 * Declaration of the alignment-training routine that, per its name, wraps a
 * CUDA implementation. Only the prototype lives in this header; the
 * definition is supplied by another translation unit not visible here.
 *
 * @param p_choose  Input tensor, taken by const reference (not modified).
 *                  NOTE(review): presumably per-step "choose" probabilities
 *                  (e.g. for monotonic attention) — confirm against the
 *                  implementation.
 * @param alpha     Tensor taken by non-const reference; presumably written
 *                  in place with the computed alignment — TODO confirm.
 * @param eps       Float scalar; presumably a numerical-stability epsilon —
 *                  verify in the implementation.
 */
void alignmentTrainCUDAWrapper(const torch::Tensor& p_choose,
                               torch::Tensor& alpha,
                               float eps);