7.3 KiB
7.3 KiB
pittrap - misc
The challenge contains an ONNX file for a neural network.
The input layer accepts a vector of length 48, with vocab_size 256. This indicates the input is a 48-character string, which is tokenized by [ord(i) for i in inp]. The output is a single score.
Code to apply the network to an input
import onnx

# Load the challenge network and verify the graph is well-formed.
onnx_model = onnx.load(ONNX_FILE_PATH)
onnx.checker.check_model(onnx_model)
print(onnx_model)

inp = None
# Write code here to apply the model on the input
# out = ???
import numpy as np
from onnx.reference import ReferenceEvaluator

# The graph's single input is named "input_ids" and takes int64 token ids of
# shape (batch_size, 48); an all-zero vector is enough to smoke-test it.
inp = np.zeros((1, 48), dtype=np.int64)
evaluator = ReferenceEvaluator(onnx_model)
outputs = evaluator.run(None, {"input_ids": inp})
score = outputs[0]
print(score)
Then I tried a gradient ascent approach to find the input that leads to the maximum score. That input was the flag.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnx import numpy_helper
from onnx.reference import ReferenceEvaluator
# Target format: gigem{just_<middle>_max_and_u'll_be_fine}
PREFIX = "gigem{"
SUFFIX = "}"
ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789_"
# Only the middle is unknown. Restricting it to a plausible alphabet avoids
# the optimizer getting stuck in high-scoring non-ASCII byte values.
n_inner = 48 - len(PREFIX) - len(SUFFIX)
# NOTE(review): by construction n_pad is always 0 here; it is kept so the
# layout arithmetic stays explicit if PREFIX/SUFFIX/n_inner ever change.
n_pad = 48 - (len(PREFIX) + n_inner + len(SUFFIX))
assert n_pad >= 0 and n_inner > 0
# Pull every initializer (embedding table, conv and fc weights/biases) out of
# the ONNX graph into plain numpy arrays, keyed by initializer name.
weights = {init.name: numpy_helper.to_array(init).copy() for init in onnx_model.graph.initializer}
class GigemTorch(nn.Module):
    """Differentiable PyTorch replica of the challenge's ONNX graph.

    The architecture (read off the ONNX initializers) is:
    256-way embedding -> Conv1d(64, 128, k=3, dilation=2, padding=2) ->
    GELU -> flatten -> Linear(6144, 64) -> GELU -> Linear(64, 1).
    forward_soft accepts a *probability* one-hot, so scores are
    differentiable with respect to the token distribution.
    """

    def __init__(self):
        super().__init__()
        self.register_buffer("embed_mat", torch.from_numpy(weights["embed.weight"]))
        self.conv = nn.Conv1d(64, 128, kernel_size=3, padding=2, dilation=2)
        self.fc1 = nn.Linear(6144, 64)
        self.fc2 = nn.Linear(64, 1)
        # Mirror every ONNX initializer into the matching torch parameter.
        param_sources = (
            (self.conv.weight, "conv.weight"),
            (self.conv.bias, "conv.bias"),
            (self.fc1.weight, "fc1.weight"),
            (self.fc1.bias, "fc1.bias"),
            (self.fc2.weight, "fc2.weight"),
            (self.fc2.bias, "fc2.bias"),
        )
        for param, key in param_sources:
            param.data.copy_(torch.from_numpy(weights[key]))

    def forward_soft(self, probs):
        """Score a relaxed input.

        probs: (batch, 48, 256) distribution over byte values per position.
        Returns a (batch,) tensor of scores.
        """
        # Soft embedding lookup: expected embedding under the distribution.
        embedded = torch.einsum("blv,vh->blh", probs, self.embed_mat)
        hidden = F.gelu(self.conv(embedded.permute(0, 2, 1)))
        hidden = F.gelu(self.fc1(hidden.flatten(1)))
        return self.fc2(hidden).squeeze(-1)
def chars_to_onehot_row(device, c):
    """Return a (1, 256) float32 one-hot row for character *c*'s byte value."""
    byte_val = ord(c) & 0xFF
    row = torch.zeros(1, 256, device=device)
    row[0, byte_val] = 1.0
    return row
def build_probs_from_middle_allowed(middle_allowed_probs, allowed_ids, device):
    """Assemble the full (batch, 48, 256) input distribution.

    middle_allowed_probs: (batch, n_inner, len(allowed_ids)) probabilities
    over the restricted alphabet; these are scattered into the 256-wide
    vocab axis. Pad, PREFIX and SUFFIX positions are frozen one-hot rows.
    """
    n_batch = middle_allowed_probs.shape[0]

    # Lift the alphabet-restricted probabilities onto the full vocab axis.
    middle = torch.zeros(n_batch, n_inner, 256, device=device)
    scatter_index = allowed_ids.view(1, 1, -1).expand(n_batch, n_inner, -1)
    middle.scatter_(2, scatter_index, middle_allowed_probs)

    def fixed_row(ch):
        # Frozen one-hot row for a known character, broadcast over the batch.
        return chars_to_onehot_row(device, ch).unsqueeze(0).expand(n_batch, -1, -1)

    pieces = []
    if n_pad > 0:
        # Padding positions are pinned to token id 0.
        pad = torch.zeros(n_batch, n_pad, 256, device=device)
        pad[:, :, 0] = 1.0
        pieces.append(pad)
    pieces.extend(fixed_row(ch) for ch in PREFIX)
    pieces.append(middle)
    pieces.extend(fixed_row(ch) for ch in SUFFIX)
    return torch.cat(pieces, dim=1)
def ids_from_middle_indices(middle_indices, allowed_ids):
    """Map argmax indices over the restricted alphabet back to the full
    48-long int64 token-id vector (pad + PREFIX + middle + SUFFIX)."""
    middle_ids = allowed_ids[middle_indices].detach().cpu().numpy()[0]
    pad_part = [0] * n_pad
    prefix_part = [ord(ch) & 0xFF for ch in PREFIX]
    middle_part = [int(v) for v in middle_ids]
    suffix_part = [ord(ch) & 0xFF for ch in SUFFIX]
    return np.array(pad_part + prefix_part + middle_part + suffix_part, dtype=np.int64)
def score_ids(model, token_ids, device):
    """Evaluate a discrete token sequence via the model's soft forward pass
    (with an exact one-hot input) and return the scalar score as a float."""
    ids = torch.tensor(token_ids, dtype=torch.long, device=device)
    one_hot = F.one_hot(ids.unsqueeze(0), 256).float()
    return float(model.forward_soft(one_hot).item())
# Use the GPU when available; the search also runs (slower) on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GigemTorch().to(device)
# Vocabulary ids the middle characters are allowed to take.
allowed_ids = torch.tensor([ord(c) for c in ALPHABET], dtype=torch.long, device=device)
restarts = 6  # independent random initializations of the logits
steps = 1000  # optimizer steps per restart
best_score = float("-inf")
best_ids = None  # best discrete 48-token candidate found so far
# Continuous relaxation search: optimize logits over the restricted alphabet,
# anneal the softmax temperature, and periodically snap to the discrete argmax
# to track the best hard candidate across restarts.
for restart in range(restarts):
    logits_inner = torch.randn(1, n_inner, len(ALPHABET), device=device) * 0.01
    logits_inner.requires_grad_(True)
    opt = torch.optim.Adam([logits_inner], lr=0.2)
    for step in range(steps):
        # Temperature decays geometrically, floored at 0.25 so the softmax
        # never becomes fully hard (which would kill the gradients).
        tau = max(0.25, 2.5 * (0.992 ** step))
        probs_inner = F.softmax(logits_inner / tau, dim=-1)
        probs = build_probs_from_middle_allowed(probs_inner, allowed_ids, device)
        soft_score = model.forward_soft(probs)
        # Entropy penalty (added to the loss, scaled by tau) pushes the
        # relaxed distribution toward near-one-hot as tau cools down.
        entropy = -(probs_inner * probs_inner.clamp_min(1e-9).log()).sum(dim=-1).mean()
        loss = -soft_score + 0.02 * tau * entropy
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_([logits_inner], max_norm=5.0)
        opt.step()
        # Every 100 steps (and at the end), evaluate the discrete argmax
        # candidate and keep the best one seen so far.
        if step % 100 == 0 or step == steps - 1:
            with torch.no_grad():
                hard_idx = logits_inner.argmax(dim=-1)
                token_ids = ids_from_middle_indices(hard_idx, allowed_ids)
                hard_score = score_ids(model, token_ids, device)
                guess_str = "".join(chr(int(t)) for t in token_ids)
                if hard_score > best_score:
                    best_score = hard_score
                    best_ids = token_ids.copy()
                print(
                    f"restart {restart} step {step:4d} tau={tau:.3f} soft={soft_score.item():.4f} discrete={hard_score:.4f} guess={guess_str!r}"
                )
# Greedy coordinate refinement over the discrete candidate.
# This fixes the usual softmax-relaxation issue where argmax is close but not exact.
charset_ids = [ord(c) for c in ALPHABET]
for refine_round in range(10):
    improved = False
    # Sweep only the middle positions; pad/PREFIX/SUFFIX are fixed.
    for pos in range(n_pad + len(PREFIX), n_pad + len(PREFIX) + n_inner):
        current = best_ids[pos]
        local_best = best_score
        local_char = current
        # Try every allowed character at this position, keeping the best.
        for cand in charset_ids:
            if cand == current:
                continue
            trial = best_ids.copy()
            trial[pos] = cand
            trial_score = score_ids(model, trial, device)
            if trial_score > local_best:
                local_best = trial_score
                local_char = cand
        if local_char != current:
            best_ids[pos] = local_char
            best_score = local_best
            improved = True
            print(f"refine round {refine_round}: score={best_score:.4f} guess={''.join(chr(int(t)) for t in best_ids)!r}")
    # Stop once a full sweep yields no single-character improvement.
    if not improved:
        break
# Decode and report the best candidate, then cross-check the score with the
# reference ONNX evaluator on the exact discrete token ids.
middle = "".join(
    chr(int(i))
    for i in best_ids[n_pad + len(PREFIX) : n_pad + len(PREFIX) + n_inner]
)
flag_only = PREFIX + middle + SUFFIX
padded_visual = "".join(chr(int(i)) for i in best_ids)
print("n_pad, n_inner:", n_pad, n_inner)
print("middle:", repr(middle))
print("flag:", repr(flag_only))
print("full 48 (repr):", repr(padded_visual))
# Independent verification against the original ONNX graph, not the torch replica.
sess = ReferenceEvaluator(onnx_model)
onnx_score = sess.run(None, {"input_ids": best_ids.reshape(1, -1).astype(np.int64)})[0]
print("ONNX score (discrete):", onnx_score)