AI에 대해 이리저리 연구하다 보면 .safetensors 파일을 수도 없이 사용하게 되는데, 보안이 유지되는 회사 내부에서는 이미 학습된 딥러닝 모델을 huggingface로부터 가져오는 요청이 나가지 않거나, 또는 민감한 회사의 정보 자체를 클라우드 서비스에서 이용하기 꺼려하는 환경에 있다면 .safetensors 파일로 모델을 로컬에 저장해서 필요한 작업을 하게 된다. 나역시 .safetensors 형식으로 된 stable diffusion 모델을 열어서 ComfyUI나 webui 등의 오픈소스 도구를 사용하지 않고 모델로부터 이미지를 출력해보고 싶은 욕구가 생겼다. 바닥부터 한번 다져가보자.
.safetensors 파일이란
.safetensors 파일은 머신러닝 모델의 가중치(weight) 데이터를 안전하고 빠르게 저장하고 불러오기 위한 파일 형식이다. 이전의 .ckpt 형식에 비해 보안에 중점을 두고 있다고 말하는데, safetensors는 텍스트 기반 저장 형식과 달리 저장된 파일이 손상되거나 임의의 코드가 포함될 가능성을 줄여준다. .safetensors 파일의 구조는 바이너리 포맷으로 설계되어 있으며, 주요 목적은 머신러닝 모델의 가중치 데이터를 안전하게 저장하고 빠르게 불러오는 것이다. 이 형식은 특히 pickle의 취약성을 피하기 위해 설계되었으며, 모델 구조나 하이퍼파라미터와 같은 메타데이터를 포함하지 않고, 가중치 데이터만을 안전하게 저장한다.
.safetensors 파일의 세부 구조는 크게 헤더와 데이터 블록으로 구성된다. 헤더는 JSON형식으로 된 각 텐서의 메타데이터 정보이고, 데이터 블록은 실제 가중치 값들이 모여있는 바이너리 형태라고 보면 된다.
예시 헤더
{
"weight_1": {"dtype": "float32", "shape": [128, 256], "offsets": [0, 131072]},
"bias_1": {"dtype": "float32", "shape": [128], "offsets": [131072, 131584]},
...
}
헤더에는 텐서의 이름, 데이터 유형, 모양(크기), 데이터의 시작 위치와 크기 등이 기록된다. JSON 포맷을 보면, "dtype"은 데이터 타입(ex. float32, int64 등), "shape"는 텐서의 크기를 정의하는 배열, "offsets"는 텐서 데이터의 바이너리 블록에서의 시작과 끝 위치를 나타내는 오프셋 값이 된다. 예시 헤더에서 float32 데이터 타입의 가중치가 [128, 256] 차원으로 존재하기 때문에 4(byte)x128x256 =. 31,072로 실제 바이너리 데이터의 사이즈와 일치하는 것을 볼 수 있다.
python에서 .safetensors 파일 열어서 확인하기
Python의 safetensors 라이브러리를 이용해서 .safetensors 파일의 구조를 살펴볼 수 있다. 다음은 .safetensors 파일을 열고 내부 텐서의 메타데이터와 데이터를 확인하는 예시이다.
from safetensors import safe_open
# safetensors 파일 열기
safetensors_file = "path_to_model.safetensors"
with safe_open(safetensors_file, framework="pt") as f:
# 파일 내 텐서 이름 확인
tensor_names = f.keys()
print("텐서 목록:", tensor_names)
# 각 텐서의 메타데이터 출력
for name in tensor_names:
tensor = f.get_tensor(name)
print(f"텐서 '{name}'의 데이터 타입: {tensor.dtype}")
print(f"텐서 '{name}'의 크기: {tensor.shape}")
.safetensors 파일을 이용해서 nsfw_image_detection 모델 돌려보기
그럼 이제 갖고 있는 .safetensors 파일을 이용해서 딥러닝 모델을 로컬 환경에서 실행해보자. 본 포스팅에서는 이미지 분류 모델 중 하나인 nsfw_image_detection 모델을 예시로 가져왔다.
https://huggingface.co/Falconsai/nsfw_image_detection
Falconsai/nsfw_image_detection · Hugging Face
Model Card: Fine-Tuned Vision Transformer (ViT) for NSFW Image Classification Model Description The Fine-Tuned Vision Transformer (ViT) is a variant of the transformer encoder architecture, similar to BERT, that has been adapted for image classification ta
huggingface.co
NSFW(Not Safe For Work) Image Classification 모델은 입력받은 이미지가 safe 한지(보통은 19금을 걸러내는 게 주목적인것 같다.) 분류하는 모델이다. 많은 이미지 작업의 시작점에서 입력받은 이미지의 NSFW score를 검사해서 내가 정한 threshold를 넘기는지 아닌지 확인하는 모델이라 할 수 있겠다. 위의 huggingface 리포지토리를 clone하면 model.safetensors 파일과 config.json 파일이 존재한다. 세이프텐서 파일과 config 파일은 짝을 이루는 파일이라고 생각하면 편하다.
{
"_name_or_path": "Falconsai/nsfw_image_detection",
"architectures": [
"ViTForImageClassification"
],
"attention_probs_dropout_prob": 0.0,
"encoder_stride": 16,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"id2label": {
"0": "normal",
"1": "nsfw"
},
"image_size": 224,
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"normal": "0",
"nsfw": "1"
},
"layer_norm_eps": 1e-12,
"model_type": "vit",
"num_attention_heads": 12,
"num_channels": 3,
"num_hidden_layers": 12,
"patch_size": 16,
"problem_type": "single_label_classification",
"qkv_bias": true,
"torch_dtype": "float32",
"transformers_version": "4.31.0"
}
config.json 파일을 보면, transformers.ViTForImageClassification 아키텍쳐를 이용해서 해당 모델을 불러와 사용할 수 있음을 알 수 있다. 파이썬 코드를 아래와 같이 작성해서 실행시켜 보자.
# Load model directly
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
# 모델과 전처리기 로드
model = ViTForImageClassification.from_pretrained("./", use_safetensors=True)
processor = ViTImageProcessor.from_pretrained("./")
# 입력 이미지 로드 및 전처리
input_image = Image.open("./test.png")
inputs = processor(images=input_image, return_tensors="pt")
# nsfw score 예측
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# 결과 확인
predicted_label = torch.argmax(logits, dim=-1).item()
print(f"Predicted label: {model.config.id2label[predicted_label]}")
이렇게 코드를 작성하고 돌려보면 내가 입력한 이미지에 대한 label을 확인할 수 있다.