SAM 2 para segmentação
Como mostramos recentemente (aqui), SAM 2 (Segment Anything Model 2) é o novo modelo de IA do Facebook/Meta para segmentação de objetos em imagens e vídeos. Ele é o primeiro modelo unificado capaz de realizar segmentação com prompts em imagens estáticas e vídeos em tempo real. Segundo seus autores, o novo modelo pode segmentar qualquer objeto em qualquer vídeo ou imagem. Neste post, mostraremos como usar SAM 2 para segmentação de imagens estáticas em Python.
Instalações Requeridas
Segundo a documentação, a execução de SAM 2 necessita de GPU. Se seu computador não possui GPU, execute o modelo no Colab.
A segmentação de imagens que realizaremos com SAM 2 requer a instalação dos seguintes pacotes: NumPy, Matplotlib, Pillow, OpenCV, PyTorch e TorchVision. Idealmente, crie um ambiente virtual novo antes de instalar os pacotes necessários. A instalação dos quatro primeiros pacotes pode ser feita com pip:
pip install matplotlib numpy opencv-python pillow
As instalações do PyTorch e TorchVision variam conforme o seu sistema. Portanto, consulte a documentação oficial (aqui) para obter informações para o seu caso.
Após realizar as instalações, clone o repositório do SAM 2:
git clone https://github.com/facebookresearch/segment-anything-2.git
Atenção: para quem usa Windows, é fortemente recomendado instalar Windows Subsystem for Linux (WSL) com Ubuntu. Veja mais detalhes aqui.
Após clonar o repositório, entre nele e instale os requerimentos finais:
cd segment-anything-2
pip install -e .
Para terminar, é preciso baixar um arquivo. yaml e os pesos do modelo que utilizaremos. Existem quatro tamanhos de modelo. Você pode obter todos eles na Hugging Face ou diretamente no repositório do SAM 2. Usaremos o modelo de menor tamanho. Ele se chama tiny. Para baixar os arquivos pela Hugging Face, acesse essa página e baixe o arquivo sam2_hiera_t.yaml e os pesos do modelo (sam2_hiera_tiny.pt). O arquivo .yaml precisa ser colocado no diretório do SAM 2 clonado do Git. O arquivo de pesos (.pt) deve ser colocado num diretório chamado checkpoints no diretório clonada do Git.
Página da Hugging Face com recursos do SAM 2.
Importações e Configurações iniciais
Com tudo instalado, vamos ao código. Ele é fortemente baseado nesse notebook. Crie um código Python com o nome que desejar. Idealmente, o crie diretamente no repositório do SAM 2 clonado do GitHub. Se você preferir não fazer isso, defina corretamente o caminho para os pesos e arquivo .yaml do modelo no trecho abaixo (linhas 10 e 11 abaixo). No início do código, fazemos as importações e as definições das configurações do modelo:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import torchvision
# seleciona os recursos do modelo
sam2_checkpoint = "checkpoints/sam2_hiera_tiny.pt" # mude o caminho aqui se precisar
model_cfg = "sam2_hiera_t.yaml" # mude o caminho aqui se precisar
Em seguida, é preciso verificar a presença de GPU. Não é totalmente garantido, mas aparentemente SAM 2 não roda sem GPU. Também é importante definir algumas configurações para o PyTorch. Usaremos torch.float16, mas você pode usar torch.bfloat16 se preferir:
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print(f"using device: {device}")
if device.type == "cuda":
torch.autocast("cuda", dtype=torch.float16).__enter__()
Funções de auxílio
Em seguida, definiremos algumas funções de auxílio. Elas foram retiradas integralmente daqui. No total, usaremos quatro funções de auxílio. Duas são utilizadas no código para especificar uma caixa ou ponto sobre o objeto onde a segmentação será feita:
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
As outras duas definem máscaras:
def show_mask(mask, ax, random_color=False, borders = True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2
contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
ax.imshow(mask_image)
def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca(), borders=borders)
if point_coords is not None:
assert input_labels is not None
show_points(point_coords, input_labels, plt.gca())
if box_coords is not None:
# boxes
show_box(box_coords, plt.gca())
if len(scores) > 1:
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
A imagem
img = Image.open('path para sua imagem') # formatos: .png, .jpg, .jpeg
A imagem precisa ser convertida em um NumPy array. Ela pode ser plotada com Matplotlib.
# converte imagem em array
image = np.array(img.convert("RGB"))
# plota imagem
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.axis('on')
plt.show()
A vantagem de plotar a imagem com Matplotlib é obter os eixos do array com suas dimensões (isso não é um plano cartesiano xy). A partir desses valores, podemos selecionar pontos ou caixas para identificar ou delimitar os objetos que queremos segmentar. SAM 2 realiza segmentações com o auxílio de prompts. Esses prompts consistem em listas de valores numéricos. Veja nossa imagem abaixo:
Definição do modelo e marcação do objeto
Para realizar a segmentação em si, é preciso iniciar o modelo:
# cria o objeto do modelo
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
# define o preditor
predictor = SAM2ImagePredictor(sam2_model)
predictor.set_image(image)
Em seguida, é necessário marcar/delimitar o objeto que será segmentado. Essa marcação corresponde ao prompt que deve ser fornecido ao SAM 2. Ela pode ser feita usando um ponto ou uma caixa delimitadora. Para esse exemplo, um ponto será utilizado:
# coordenadas de pontos para identificar o objeto para segmentar
input_point = np.array([[300, 175]])
input_label = np.array([1])
# plota imagem com ponto marcado
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
O ponto na imagem é mostrado com uma estrela:
Segmentação com SAM 2
A segmentação é feita a seguir:
# realiza a segmentação
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
# exibe as imagems com segmentação
show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
Os resultados retornam três segmentações com seus scores. Veja as segmentações obtidas abaixo:
Como é possível observar, os resultados são ótimos e não dependem de rótulos. Suas qualidades surpreendem ainda mais quando consideramos que o modelo usado foi o menor disponível. Para aprimorar ainda mais a qualidade dos resultados, uma alternativa adicional é fornecer uma lista de pontos ao invés de apenas um. SAM 2 também aceita vários pontos para segmentar múltiplos objetos.
Segmentação com caixas delimitadoras
Para segmentar usando uma caixa delimitadora, o processo é parecido. Portanto, faremos apenas uma continuação no código anterior. Utilizaremos uma nova imagem:
# abre nova imagem
img2 = Image.open('local da sua imagem')
# converte imagem em array
image2 = np.array(img2.convert("RGB"))
A imagem é mostrada abaixo:
Assim como foi feito com a determinação de um ponto, a caixa delimitadora precisa ser criada usando coordenadas que variam conforme a imagem e o objeto que se deseja segmentar. As coordenadas que utilizaremos irão delimitar o rosto da moça do meio:
input_box = np.array([150, 0, 350, 250])
Depois, basta chamar a função de segmentação:
masks, scores, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
show_masks(image2, masks, scores, box_coords=input_box)
E esse é o resultado:
Conclusão
Neste post, usamos SAM 2 para realizar segmentações em imagens estáticas. Não utilizamos vídeos para não nos estendermos demais. Mas a segmentação em vídeos é parecida. Além do procedimento que utilizamos, SAM 2 realiza segmentações com geração automática de máscaras. Para conhecer mais sobre ele, consulte a página oficial aqui.