Spaces:
Build error
Build error
void vecquant2matmul_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
); | |
void vecquant2matmul( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant2matmul_cuda(vec, mat, mul, scales, zeros,groupsize); | |
} | |
void vecquant3matmul_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
); | |
void vecquant3matmul( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant3matmul_cuda(vec, mat, mul, scales, zeros, groupsize); | |
} | |
void vecquant4matmul_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
); | |
void vecquant4matmul( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant4matmul_cuda(vec, mat, mul, scales, zeros, groupsize); | |
} | |
void vecquant8matmul_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
); | |
void vecquant8matmul( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant8matmul_cuda(vec, mat, mul, scales, zeros, groupsize); | |
} | |
void vecquant2matmul_faster_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize, int vec_height | |
); | |
void vecquant2matmul_faster( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize, int vec_height | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant2matmul_faster_cuda(vec, mat, mul, scales, zeros, groupsize, vec_height); | |
} | |
void vecquant3matmul_faster_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize, int vec_height | |
); | |
void vecquant3matmul_faster( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize, int vec_height | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros, groupsize, vec_height); | |
} | |
void vecquant4matmul_faster_cuda( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize, int vec_height | |
); | |
void vecquant4matmul_faster( | |
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, | |
torch::Tensor scales, torch::Tensor zeros, | |
int groupsize, int vec_height | |
) { | |
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | |
vecquant4matmul_faster_cuda(vec, mat, mul, scales, zeros, groupsize, vec_height); | |
} | |
// Python bindings for the extension module named by TORCH_EXTENSION_NAME.
// Each host wrapper is registered under its own C++ name, with a short
// description string that becomes the Python-side docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
  // The description here used to read "4-bit" (copy-paste slip); this entry
  // binds the 2-bit function, so the docstring now says so.
  m.def("vecquant2matmul_faster", &vecquant2matmul_faster, "Vector 2-bit Quantized Matrix Multiplication (CUDA), faster version");
  m.def("vecquant3matmul_faster", &vecquant3matmul_faster, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version");
  m.def("vecquant4matmul_faster", &vecquant4matmul_faster, "Vector 4-bit Quantized Matrix Multiplication (CUDA), faster version");
}