diff --git a/shifthappens/tasks/lost_in_translation/affine_transformations/__init__.py b/shifthappens/tasks/lost_in_translation/affine_transformations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/shifthappens/tasks/lost_in_translation/affine_transformations/affine.py b/shifthappens/tasks/lost_in_translation/affine_transformations/affine.py new file mode 100644 index 00000000..9515a34f --- /dev/null +++ b/shifthappens/tasks/lost_in_translation/affine_transformations/affine.py @@ -0,0 +1,717 @@ +import torch +import numpy as np +from skimage.transform import resize +import skimage.io as io +import kornia.geometry.transform as t +import torch.nn as nn +import random +import matplotlib.pyplot as plt +import torch.nn.functional as F + +config_imagenet = { + 'target_size': 254, + 'crop_size': 254, + #'add_space_factor': 1 #left,right,upper,lower +} + +config = config_imagenet + + +def square_dim(old_start, old_length, new_length, space, oversized_allowed=False): + space_left = old_start + space_right = space - (old_start + old_length) + half_space_needed_float = (new_length - old_length)/2.0 + half_space_needed = int(half_space_needed_float) + assert half_space_needed >= 0 + + if space_left > half_space_needed_float and space_right > half_space_needed_float: + new_start = old_start - half_space_needed + elif space_left < half_space_needed_float and space_right > (half_space_needed_float + half_space_needed - space_left): + new_start = 0 + elif space_left > (half_space_needed_float + half_space_needed_float - space_right) and space_right < half_space_needed_float: + new_start = old_start - ((new_length - old_length) - space_right) + else: + if not oversized_allowed: + raise Exception("unable to compute square") + else: + new_start = old_start - half_space_needed + return new_start, new_length + +def square_bbox(coco_bbox, h, w, oversized_allowed=False): + bbox_x, bbox_y, bbox_w, bbox_h = map(int,coco_bbox) + if bbox_w > bbox_h: + bbox_y, bbox_h = square_dim(bbox_y, bbox_h, bbox_w, h, oversized_allowed=oversized_allowed) + elif bbox_w < bbox_h: + bbox_x, bbox_w = square_dim(bbox_x, bbox_w, bbox_h, w, oversized_allowed=oversized_allowed) + + return int(bbox_x), int(bbox_w), int(bbox_y), int(bbox_h) + +def get_bounds(img, do_check=True): + rows = np.any(img, axis=1) + cols = np.any(img, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + if do_check: + assert (rmax - rmin) == (cmax - cmin) + + return rmin, rmax, cmin, cmax + +def load_and_pad_imagenet(I: np.ndarray, gt_uint: np.ndarray, in_label, load_other_masks_fun=None, accept_smaller=True): + m = np.where(gt_uint == in_label) + + bbox_y = np.min(m[0]) + bbox_h = np.max(m[0]) - np.min(m[0]) + bbox_x = np.min(m[1]) + bbox_w = np.max(m[1]) - np.min(m[1]) + bbox = (bbox_x, bbox_y, bbox_w, bbox_h) + I_mask = (gt_uint == in_label).astype(int) + return load_and_pad(I, I_mask, bbox, load_other_masks_fun=load_other_masks_fun, accept_smaller=accept_smaller) + +def find_enclosing_crop(bbox_item, h, w, crop_size): + (bbox_x, bbox_y, bbox_w, bbox_h) = bbox_item + def calc_bounds_side(bb_start, bb_l, length): + middle = bb_start + int(0.5*bb_l) + half_crop_l = int(0.5*crop_size) + half_crop_r = crop_size - int(0.5*crop_size) #if not div by 2 + space_needed_left = max(half_crop_l-middle, 0) + available_l = length - 1 + space_needed_right = max((middle + half_crop_r) - available_l, 0) + start_left = max(middle - half_crop_l, 0) - space_needed_right + if start_left < 0: 
raise Exception("unable to crop element") + end_right = min(middle + half_crop_r, available_l) + space_needed_left + if end_right > length: + raise Exception("unable to crop element") + return start_left, end_right + crop_start_y, crop_end_y = calc_bounds_side(bbox_y, bbox_h, h) + crop_start_x, crop_end_x = calc_bounds_side(bbox_x, bbox_w, w) + return crop_start_y, crop_end_y, crop_start_x, crop_end_x + + +def load_and_pad(I: np.ndarray, I_mask: np.ndarray, bbox, load_other_masks_fun=None, accept_smaller=False): + crop_size = config['crop_size'] + h,w = I.shape[:2] + + if min(h,w) < crop_size: + raise Exception("not possible") + + (bbox_x, bbox_y, bbox_w, bbox_h) = bbox + + + sub = I[bbox_y:(bbox_y + bbox_h),bbox_x:(bbox_x + bbox_w)] + if load_other_masks_fun is not None: + other_masks = load_other_masks_fun() + I_mask = I_mask + other_masks + tup = (bbox_x, bbox_y, bbox_w, bbox_h) + ratio = float(crop_size) / max(bbox_h, bbox_w ) + I_bounding = np.zeros((I.shape[0], I.shape[1])) + I_bounding[bbox_y:(bbox_y + bbox_h),bbox_x:(bbox_x + bbox_w)] = 1. + if ratio < 1.0: + sq_x, sq_w, sq_y, sq_h = square_bbox(tup, h, w) + I_crop = np.zeros((I.shape[0], I.shape[1])) + I_crop[sq_y:(sq_y + sq_h),sq_x:(sq_x + sq_w)] = 1. + + start_crop_x= max(sq_x - sq_w, 0) + end_crop_x = min(sq_x + sq_w + sq_w, I.shape[1]) + start_crop_y = max(sq_y - sq_h, 0) + end_crop_y = min(sq_y + sq_h + sq_h, I.shape[0]) + + sub = I[start_crop_y:end_crop_y,start_crop_x:end_crop_x] + sub_mask = I_mask[start_crop_y:end_crop_y,start_crop_x:end_crop_x] + sub_bounding = I_bounding[start_crop_y:end_crop_y,start_crop_x:end_crop_x] + sub_crop = I_crop[start_crop_y:end_crop_y,start_crop_x:end_crop_x] + + target_size = (int(sub.shape[0]*ratio),int(sub.shape[1]*ratio)) + sub_small = resize(sub, target_size, anti_aliasing=True) + sub_mask_small = resize(sub_mask.astype(float), target_size, anti_aliasing=True) + sub_bounding_small = resize(sub_bounding.astype(float), target_size, anti_aliasing=True) + sub_crop_small = resize(sub_crop.astype(float), target_size, anti_aliasing=True) + #correct it + rmin, rmax, cmin, cmax = get_bounds(sub_crop_small, do_check=False) + row_length = (rmax + 1) - rmin + col_length = (cmax + 1) - cmin + #print(f"row_length: {row_length}, col_length: {col_length}") + + sub_crop_small[rmin:(rmin + crop_size), cmin:(cmin + crop_size)] = 1.0 + sub_crop_small[(rmin + crop_size):] = 0.0 + sub_crop_small[:,(cmin + crop_size):] = 0.0 + rmin, rmax, cmin, cmax = get_bounds(sub_crop_small, do_check=False) + row_length = (rmax + 1) - rmin + col_length = (cmax + 1) - cmin + #assert row_length == crop_size and col_length == crop_size + + rmin, rmax, cmin, cmax = get_bounds(sub_crop_small, do_check=False) + + before_y = max(crop_size - rmin, 0) + after_y = max(crop_size*3 - (sub_crop_small.shape[0] + before_y), 0) + before_x = max(crop_size - cmin, 0) + after_x = max(crop_size*3 - (sub_crop_small.shape[1] + before_x), 0) + + to_pad = ((before_y, after_y), (before_x, after_x)) + + sub_padded = np.pad(sub_small, list(to_pad) + [(0,0)], mode='constant') + sub_mask_padded = np.pad(sub_mask_small, to_pad, mode='constant') + sub_bounding_padded = np.pad(sub_bounding_small, to_pad, mode='constant') + sub_crop_padded = np.pad(sub_crop_small, to_pad, mode='constant') + + length = 3*crop_size + + sub_padded = sub_padded[:length,:length] + assert sub_padded.shape[0] == length and sub_padded.shape[1] == length + sub_mask_padded = sub_mask_padded[:length,:length] + sub_bounding_padded = sub_bounding_padded[:length,:length] + 
sub_crop_padded = sub_crop_padded[:length,:length] + else: + assert accept_smaller + crop_start_y, crop_end_y, crop_start_x, crop_end_x = find_enclosing_crop(tup, h, w, crop_size) + assert (crop_end_y - crop_start_y) == crop_size + assert (crop_end_x - crop_start_x) == crop_size + temp = I[crop_start_y:(crop_end_y), crop_start_x:(crop_end_x)] + assert temp.shape[0] == crop_size and temp.shape[1] == crop_size + start_bound_y = max(crop_start_y - crop_size, 0) + end_bound_y = min(crop_end_y + crop_size, h) + start_bound_x = max(crop_start_x - crop_size, 0) + end_bound_x = min(crop_end_x + crop_size, w) + + do_crop = lambda arr: arr[start_bound_y:(end_bound_y), start_bound_x:(end_bound_x)] + sub = do_crop(I).astype(float)/255. + sub_mask = do_crop(I_mask).astype(float) + sub_bounding = do_crop(I_bounding).astype(float) + I_crop = np.zeros_like(I_bounding) + I_crop[crop_start_y:(crop_end_y), crop_start_x:(crop_end_x)] = 1.0 + sub_crop = do_crop(I_crop).astype(float) + + before_y = max(crop_size - crop_start_y, 0) + after_y = max(crop_size*3 - (sub_crop.shape[0] + before_y), 0) + before_x = max(crop_size - crop_start_x, 0) + after_x = max(crop_size*3 - (sub_crop.shape[1] + before_x), 0) + + to_pad = ((before_y, after_y), (before_x, after_x)) + + sub_padded = np.pad(sub, list(to_pad) + [(0,0)], mode='constant') + sub_mask_padded = np.pad(sub_mask, to_pad, mode='constant') + sub_bounding_padded = np.pad(sub_bounding, to_pad, mode='constant') + sub_crop_padded = np.pad(sub_crop, to_pad, mode='constant') + + length = 3*crop_size + + sub_padded = sub_padded[:length,:length] + assert sub_padded.shape[0] == length and sub_padded.shape[1] == length + sub_mask_padded = sub_mask_padded[:length,:length] + sub_bounding_padded = sub_bounding_padded[:length,:length] + sub_crop_padded = sub_crop_padded[:length,:length] + + #idea: 1st: try find square crop + #2nd: no resizing needed, but pad + #finished + import math + if not math.isclose(sub_mask_padded.max(), 1.0, rel_tol=1e-05, abs_tol=1e-08): + raise Exception("not enough mask! 
We are scaling it too small") + #assert math.isclose(sub_mask_padded.max(), 1.0, rel_tol=1e-01) + #to coordinates again: + b_rmin, b_rmax, b_cmin, b_cmax = get_bounds(sub_bounding_padded, do_check=False) + #correct it...sometimes it get's off by a few pixels if we are scaling a lot + def correct_b(b_min, b_max): + leng = (b_max + 1) - b_min + diff = max(leng - crop_size,0) + assert diff <= 10 and diff >= 0 + return b_max - diff + b_rmax = correct_b(b_rmin, b_rmax) + b_cmax = correct_b(b_cmin, b_cmax) + + temp = sub_bounding_padded[b_rmin:(b_rmax + 1), b_cmin:(b_cmax + 1)] + assert temp.shape[0] <= crop_size and temp.shape[1] <= crop_size + bound_coords = ((b_rmin, b_rmax), (b_cmin, b_cmax)) + + c_rmin, c_rmax, c_cmin, c_cmax = get_bounds(sub_crop_padded, do_check=False) + def correct_c(c_min, c_max): + leng = (c_max + 1) - c_min + diff = leng - crop_size + return c_max - diff + c_rmax = correct_c(c_rmin, c_rmax) + c_cmax = correct_c(c_cmin, c_cmax) + temp = sub_crop_padded[c_rmin:(c_rmax + 1), c_cmin:(c_cmax + 1)] + assert temp.shape[0] == crop_size and temp.shape[1] == crop_size + crop_coords = ((c_rmin, c_rmax), (c_cmin, c_cmax)) + + start_end_coord = ((before_y, crop_size*3-after_y), (before_x, crop_size*3-after_x)) + + return sub_padded, sub_mask_padded, bound_coords, crop_coords, start_end_coord + + +def to_torch(img, mask, bound_coord, crop_coord, start_end_coord): + t_img = torch.from_numpy(img).to(torch.float32).permute(2,0,1).unsqueeze(0) + t_mask = torch.from_numpy(mask).to(torch.float32).unsqueeze(0) + t_bound_coord = torch.from_numpy(np.array(bound_coord)).unsqueeze(0).float() + t_crop_coord = torch.from_numpy(np.array(crop_coord)).unsqueeze(0).float() + t_start_coord = torch.from_numpy(np.array(start_end_coord)).unsqueeze(0).float() + return (t_img, t_mask, t_bound_coord, t_crop_coord, t_start_coord) + +def expand_data(img, mask, bound_coord, crop_coord, start_end_coord, num_expand): + size = 3*config['crop_size'] + out_imgs = img.expand(num_expand,3,size,size) + out_masks = mask.expand(num_expand,size,size) + out_bounds = bound_coord.expand(num_expand,2,2) + out_crops = crop_coord.expand(num_expand,2,2) + out_starts_ends = start_end_coord.expand(num_expand,2,2) + return out_imgs, out_masks, out_bounds, out_crops, out_starts_ends + + +def get_centers_helper(rmins, rmaxs, cmins, cmaxs): + rcenter = (rmins + (rmaxs + 1 - rmins)/2).int() + ccenter = (cmins + (cmaxs + 1 - cmins)/2).int() + return torch.stack([rcenter, ccenter], dim=1) + +def get_centers(bounds): + rmins = bounds[:,0,0] + rmaxs = bounds[:,0,1] + cmins = bounds[:,1,0] + cmaxs = bounds[:,1,1] + return get_centers_helper(rmins, rmaxs, cmins, cmaxs) + +def get_zooms(bounds, crops): + def get_hw(coords): + rmins = coords[:,0,0] + rmaxs = coords[:,0,1] + cmins = coords[:,1,0] + cmaxs = coords[:,1,1] + return (rmaxs - rmins), (cmaxs - cmins) + h_b,w_b = get_hw(bounds) + h_c,w_c = get_hw(crops) + r_h = h_b/h_c + r_w = w_b/w_c + #assert r_h.max() < 1.0001 + #assert r_w.max() < 1.0001 + return torch.max(r_h, r_w) + + +class STECeil(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input.ceil() + + @staticmethod + def backward(ctx, grad_output): + return F.hardtanh(grad_output) + +ste_ceil = STECeil.apply + +class STEFloor(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + return input.floor() + + @staticmethod + def backward(ctx, grad_output): + return F.hardtanh(grad_output) + +ste_floor = STEFloor.apply + +def get_bounds_t(coords, do_check=True, quantize=False): + rmins = 
coords[:,0,0] + rmaxs = coords[:,0,1] + cmins = coords[:,1,0] + cmaxs = coords[:,1,1] + if do_check: + assert torch.allclose((rmaxs - rmins), (cmaxs - cmins)) + + if quantize: + rmins = ste_ceil(rmins) + rmaxs = ste_floor(rmaxs) + cmins = ste_ceil(cmins) + cmaxs = ste_floor(cmaxs) + return rmins, rmaxs, cmins, cmaxs + + +def get_zoom_bounds(crops, start_coords): + h_start, h_end, w_start, w_end = get_bounds_t(start_coords, quantize=True, do_check=False) + + rmins, rmaxs, cmins, cmaxs = get_bounds_t(crops) + centers = get_centers_helper(rmins, rmaxs, cmins, cmaxs) + scale_left = (centers[:,0] - rmins)/(centers[:,0] - h_start) + scale_right = (rmaxs - centers[:,0])/(h_end - centers[:,0]) + scale_top = (centers[:,1] - cmins)/(centers[:,1] - w_start) + scale_down = (cmaxs - centers[:,1])/(w_end - centers[:,1]) + stacked = torch.stack([scale_left, scale_right, scale_top, scale_down], dim=1) + return torch.max(stacked, dim=1)[0], centers + +def zoom_coords(coords, scales, centers): + out_coords = torch.zeros_like(coords) + out_coords[:,:,0] = (centers - scales * (centers - coords[:,:,0])) + out_coords[:,:,1] = (centers + scales * (coords[:,:,1] - centers)) + return out_coords + + +def find_height_max_enlosing_rect(crops, start_coords): + h_start, h_end, w_start, w_end = get_bounds_t(start_coords, quantize=True, do_check=False) + + rmins, rmaxs, cmins, cmaxs = get_bounds_t(crops, do_check=False) + + space_below = rmins - h_start + space_above = h_end - (rmaxs) + space_vertical = torch.min(space_below, space_above) + space_left = cmins - w_start + space_right = w_end - (cmaxs) + space_horizontal = torch.min(space_left, space_right) + + inner_heights =(rmaxs - rmins) + outer_heights = (2*space_vertical + 1) + inner_heights + + inner_widths = (cmaxs - cmins) + outer_widths = (2*space_horizontal + 1) + inner_widths + + # inner_heights = (h_end +1 - h_start) - 2*space_vertical + # inner_widths = (w_end +1 - w_start) - 2*space_horizontal + + return outer_heights, outer_widths#, inner_heights, inner_widths + +def find_height_inclosing_rect_same_center(crops, bounds): + h_start, h_end, w_start, w_end = get_bounds_t(crops, quantize=True, do_check=False) + + rmins, rmaxs, cmins, cmaxs = get_bounds_t(bounds, do_check=False) + + def get_length(i_start,i_end,o_start,o_end): + left = i_start - o_start + right = o_end - i_end + return (o_end + 1 - o_start) - 2*torch.min(left, right) + + # space_below = rmins - h_start + # space_above = h_end - (rmaxs) + # space_vertical = torch.max(space_below, space_above) + # space_left = cmins - w_start + # space_right = w_end - (cmaxs) + # space_horizontal = torch.max(space_left, space_right) + + # inner_heights =(rmaxs - rmins) + # heights = (2*space_vertical + 1) + inner_heights + + # inner_widths = (cmaxs - cmins) + # widths = (2*space_horizontal + 1) + inner_widths + + heights = get_length(rmins,rmaxs,h_start,h_end) + widths = get_length(cmins,cmaxs,w_start,w_end) + + return heights, widths + +import math +def find_max_angle(bounds, crops, start_coords, resize): + def calc_max_angles(cube_lengths, outer): + space_left = (outer.float() - cube_lengths)/2 + # height is: + # H = C*sin(tetha)*cos(tetha) + # so + # (sin * cos)^-1(H/C) =tetha + max_angle = torch.zeros(outer.shape[0]).to(cube_lengths.device) + #space_left == 0 => max_angle = 0 + + ratio_denom = cube_lengths[space_left != 0].float() + ratio = space_left[space_left != 0] / ratio_denom + + calc_max_angle = torch.empty(ratio.shape[0]).to(cube_lengths.device) + calc_max_angle[ratio >= 1/2.] 
= 2*np.pi + calc_max_angle[ratio < 1/2.] = (1/2.*torch.asin(2*ratio[ratio < 1/2.])).abs() + + max_angle[space_left != 0] = calc_max_angle + + return max_angle + + def calc_max_angle_2(a,b,A): + max_angle = torch.zeros(A.shape[0], device=A.device) + denom = torch.sqrt(((a/2)**2) + ((b/2)**2)) + temp1 = (A/2.0)/denom + temp2 = (a/2.0)/denom + max_angle[temp1 >= 1.0] = 2*np.pi + max_angle[temp1 < 1.0] = torch.asin(temp1[temp1 < 1.0])-torch.asin(temp2[temp1 < 1.0]) + return max_angle + + if resize: + rmins, rmaxs, cmins, cmaxs = get_bounds_t(crops) + inner_heights =(rmaxs - rmins) + inner_widths = (cmaxs - cmins) + assert torch.all(inner_heights == inner_widths) + enclosing_heights, enclosing_widths = find_height_max_enlosing_rect(crops, start_coords) + + max_angles_1 = calc_max_angles(inner_heights, enclosing_heights) + max_angles_2 = calc_max_angles(inner_widths, enclosing_widths) + return torch.min(max_angles_1, max_angles_2) + else: + def find_non_resize_angle(inner, outer): + #inner_heights, inner_widths = find_height_inclosing_rect_same_center(outer, inner) + enclosing_heights, enclosing_widths = find_height_max_enlosing_rect(inner, outer) + rmins, rmaxs, cmins, cmaxs = get_bounds_t(inner, do_check=False) + inner_heights = (rmaxs + 1) - rmins + inner_widths = (cmaxs + 1) - cmins + max_angles_1 = calc_max_angle_2(inner_heights, inner_widths, enclosing_heights) + max_angles_2 = calc_max_angle_2(inner_widths, inner_heights, enclosing_widths) + return torch.min(max_angles_1, max_angles_2) + + inner_angle = find_non_resize_angle(bounds, crops) + + outer_angle = find_non_resize_angle(crops, start_coords) + + # #TODO: our picture_bounds (start-coords) could get inside our image! + # #rmins, rmaxs, cmins, cmaxs = get_bounds_t(bounds, do_check=False) + # inner_heights, inner_widths = find_height_inclosing_rect_same_center(crops, bounds) + # # inner_heights =(rmaxs + 1) - rmins + # # inner_widths = (cmaxs + 1) - cmins + # enclosing_heights, enclosing_widths = find_height_max_enlosing_rect(bounds, crops) + + # # max_angles_1 = calc_max_angles(inner_heights, enclosing_heights) + # # max_angles_2 = calc_max_angles(inner_widths, enclosing_widths) + # max_angles_1 = calc_max_angle_2(inner_heights, inner_widths, enclosing_heights) + # max_angles_2 = calc_max_angle_2(inner_widths, inner_heights, enclosing_widths) + # inner_angle = torch.min(max_angles_1, max_angles_2) + + + #outer_angle = find_max_angle(bounds, crops, start_coords, True) + return torch.min(inner_angle, outer_angle) + +def rotate(imgs, masks, bounds, crops, start_coords, angles, max_angle, verbose=True, resize=False, disable_check=False): + if not disable_check: + with torch.no_grad(): + max_possible_angles = find_max_angle(bounds, crops, start_coords, resize) + temp = angles[angles.abs()>max_possible_angles] + rhs = temp.sign()*max_possible_angles[angles.abs()>max_possible_angles] + angles[angles.abs()>max_possible_angles] = rhs + temp = angles[angles.abs()>max_angle] + angles[angles.abs()>max_angle] = temp.sign()*max_angle + + rmins, rmaxs, cmins, cmaxs = get_bounds_t(crops) + centers = get_centers_helper(rmins, rmaxs, cmins, cmaxs).float() + centers = torch.flip(centers, dims=(-1,)) + + + real_angles = angles/np.pi * 180. 
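+ # note: the sampled angles are in radians, but kornia's rotation/affine helpers expect degrees, hence the conversion above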
+ if verbose: + print(f"rotating with {real_angles}") + + real_angles = real_angles.expand(imgs.shape[0]) + centers = centers.expand(centers.shape[0], -1) + import kornia + mode: str = 'bilinear' + padding_mode: str = 'zeros' + align_corners: bool = True + rotation_matrix: torch.Tensor = kornia.geometry.transform.affwarp._compute_rotation_matrix(real_angles, centers) + + def rotate_coords(coords, mat): + rmins, rmaxs, cmins, cmaxs = get_bounds_t(coords, do_check=False) + + # dim [B,2] + upper_left = torch.stack((rmins, cmins), dim=1) + upper_right = torch.stack((rmins, cmaxs), dim=1) + down_left = torch.stack((rmaxs, cmins), dim=1) + down_right = torch.stack((rmaxs, cmaxs), dim=1) + + # dim [B,4,2] + rectangle_dims = torch.stack((upper_left, upper_right, down_left, down_right),dim=-2) + all_dims = torch.flip(rectangle_dims, dims=(-1,)) + # dim [B,4,3] + all_dims_affine = torch.concat([all_dims, torch.ones((all_dims.shape[0],4,1), device=all_dims.device)], dim=-1) + + # [B,2,3], [B,4,3] -> [B,4,2] + rotated = torch.einsum("boi, bri -> bro", mat, all_dims_affine) + rotated_flipped = torch.flip(rotated, dims=(-1,)) + return rotated_flipped + + def rotate_coords_and_fit_rectangle(coords, mat): + rotated_flipped = rotate_coords(coords, mat) + start_h = rotated_flipped[:,:,0].min(dim=1)[0] + end_h = rotated_flipped[:,:,0].max(dim=1)[0] + bounds_h = torch.stack([start_h, end_h], dim=-1) + start_w = rotated_flipped[:,:,1].min(dim=1)[0] + end_w = rotated_flipped[:,:,1].max(dim=1)[0] + bounds_w = torch.stack([start_w, end_w], dim=-1) + + return torch.stack([bounds_h, bounds_w], dim=-2) + + if resize: + bounds_crops = rotate_coords_and_fit_rectangle(crops, rotation_matrix[..., :2, :3]) + size_h = bounds_crops[:,0,1] - bounds_crops[:,0,0] + size_w = bounds_crops[:,1,1] - bounds_crops[:,1,0] + rotated_size = torch.stack((size_h, size_w), dim=1).max(dim=1)[0] + + scale_factor = float(config['crop_size'])/rotated_size + if len(scale_factor.shape) == 1: + scale_factor = scale_factor.unsqueeze(1).repeat(1, 2) + #scale_factor = scale_factor.repeat(1, 2) + scaling_matrix: torch.Tensor = kornia.geometry.transform.affwarp._compute_scaling_matrix(scale_factor, centers) + ones = torch.tensor([0,0,1], device=rotation_matrix.device).view(1,1,3).expand((scaling_matrix.shape[0],1,3)) + temp = torch.concat([rotation_matrix[..., :2, :3], ones], dim=1) + operation = torch.bmm(scaling_matrix[..., :2, :3], temp) + + if True: + bounds_crops = rotate_coords_and_fit_rectangle(crops, operation) + size_h = bounds_crops[:,0,1] - bounds_crops[:,0,0] + size_w = bounds_crops[:,1,1] - bounds_crops[:,1,0] + rotated_size = torch.stack((size_h, size_w), dim=1).max(dim=1)[0] + assert torch.allclose(rotated_size, torch.tensor(float(config['crop_size']))) + + rotate_imgs = t.affine(imgs, operation, mode, padding_mode, align_corners) + rotate_masks = t.affine(masks.unsqueeze(1), operation, mode, padding_mode, align_corners).squeeze(1) + + rotated_bounds = rotate_coords_and_fit_rectangle(bounds, operation) + #rotated_start_coords = transform_coords(start_coords) + + return rotate_imgs, rotate_masks, rotated_bounds, crops, None + else: + #plot_debug_2(rotation_matrix[..., :2, :3]) + bounds_bounds = rotate_coords_and_fit_rectangle(crops, rotation_matrix[..., :2, :3]) + rotate_imgs = t.affine(imgs, rotation_matrix[..., :2, :3], mode, padding_mode, align_corners) + rotate_masks = t.affine(masks.unsqueeze(1), rotation_matrix[..., :2, :3], mode, padding_mode, align_corners).squeeze(1) + return rotate_imgs, rotate_masks, bounds_bounds, crops, 
None + + +def random_rotation(imgs, masks, bounds, crops, start_coords, max_angle=45./180.*np.pi): + angles = find_max_angle(crops, start_coords) + random = -angles + (2*angles * torch.rand(imgs.shape[0])) + return rotate(imgs, masks, bounds, crops, start_coords, random, max_angle), random + +def calculate_translate_bounds(crops, bounds, start_coords): + i_rmins, i_rmaxs, i_cmins, i_cmaxs = get_bounds_t(bounds, do_check=False) + + rmins, rmaxs, cmins, cmaxs = get_bounds_t(crops) + + space_below = i_rmins - rmins + space_above = (rmaxs - 1) - (i_rmaxs - 1) + space_left = i_cmins - cmins + space_right = (cmaxs - 1) - (i_cmaxs - 1) + + h_start, h_end, w_start, w_end = get_bounds_t(start_coords, quantize=True, do_check=False) + + image_space_above = h_end - rmaxs + image_space_below = rmins - h_start + image_space_left = cmins - w_start + image_space_right = w_end - cmaxs + + + + move_max_below = torch.min(image_space_above, space_below) + move_max_above = torch.min(image_space_below, space_above) + move_max_left = torch.min(image_space_left, space_right) + move_max_right = torch.min(image_space_right, space_left) + + #todo we might shift too much and + + return move_max_below, move_max_above, move_max_left, move_max_right + +def translate_xy(imgs, masks, bounds, crops, start_coords, trans_x, trans_y, verbose=True): + with torch.no_grad(): + t_bounds = calculate_translate_bounds(crops, bounds, start_coords) + move_max_below, move_max_above, move_max_left, move_max_right = t_bounds + dtype = trans_x.dtype + + move_max_below = move_max_below.to(dtype) + move_max_above = move_max_above.to(dtype) + move_max_left = move_max_left.to(dtype) + move_max_right = move_max_right.to(dtype) + + trans_x[trans_x < -move_max_above] = -move_max_above[trans_x < -move_max_above] + trans_x[trans_x > move_max_below] = move_max_below[trans_x > move_max_below] + + # trans_y[trans_y < -move_max_right] = -move_max_right[trans_y < -move_max_right] + # trans_y[trans_y > move_max_left] = move_max_left[trans_y > move_max_left] + + # trans_x[trans_x < -move_max_below] = -move_max_below[trans_x < -move_max_below] + # trans_x[trans_x > move_max_above] = move_max_above[trans_x > move_max_above] + + trans_y[trans_y < -move_max_left] = -move_max_left[trans_y < -move_max_left] + trans_y[trans_y > move_max_right] = move_max_right[trans_y > move_max_right] + + if verbose: + print(f"real translating with {trans_x, trans_y}") + t_vecs = torch.stack([trans_y, trans_x], dim=1) + t_vecs = (-1)*t_vecs + do_it = lambda i: t.translate(i, t_vecs) + out_imgs = do_it(imgs) + out_masks = do_it(masks.unsqueeze(1)).squeeze(1) + + len_available_b = 3*config['crop_size'] + + out_bounds = bounds.clone() + out_bounds[:,0] = (out_bounds[:,0] - trans_x.unsqueeze(1)).clamp(0,len_available_b) + out_bounds[:,1] = (out_bounds[:,1] - trans_y.unsqueeze(1)).clamp(0,len_available_b) + + start_ct = start_coords.clone() + start_ct[:,0] = (start_ct[:,0] - trans_x.unsqueeze(1)).clamp(0,len_available_b) + start_ct[:,1] = (start_ct[:,1] - trans_y.unsqueeze(1)).clamp(0,len_available_b) + + return out_imgs, out_masks, out_bounds, crops, start_ct + + #out_crops = do_it(crops) + #return imgs, masks, bounds, out_crops, start_coords + +def random_translate(imgs, masks, bounds, crops, start_coords): + t_bounds = calculate_translate_bounds(crops, bounds, start_coords) + move_max_below, move_max_above, move_max_left, move_max_right = t_bounds + print(move_max_below, move_max_above, move_max_left, move_max_right) + t_x = -move_max_below + (move_max_below + move_max_above) * 
torch.rand(imgs.shape[0]) + t_y = -move_max_left + (move_max_left + move_max_right) * torch.rand(imgs.shape[0]) + print(f"translating with {t_x, t_y}") + return translate_xy(imgs, masks, bounds, crops, start_coords, t_x, t_y), (t_x,t_y) + +def crop_batches(imgs: torch.Tensor, masks: torch.Tensor, bounds: torch.Tensor, + crops: torch.Tensor, start_coords: torch.Tensor): + start_y = crops[:,0,0] + start_x = crops[:,1,0] + trans_y = config['crop_size']-start_y + trans_x = config['crop_size']-start_x + t_vecs = torch.stack([trans_x, trans_y], dim=1) + do_it = lambda i: t.translate(i, t_vecs) + if not (t_vecs == 0).all().item(): + out_imgs = do_it(imgs) + out_masks = do_it(masks.unsqueeze(1)).squeeze(1) + else: + out_imgs = imgs + out_masks = masks + start = config['crop_size'] + end = config['crop_size'] + config['crop_size'] + cropped_imgs = out_imgs[:,:,start:(end),start:(end)] + cropped_masks = out_masks[:,start:(end),start:(end)] + max_m = cropped_masks.amax(dim=(1, 2)) + + bounds_res = bounds.clone() + bounds_res[:,0] = bounds_res[:,0] - start_y.unsqueeze(-1) + bounds_res[:,1] = bounds_res[:,1] - start_x.unsqueeze(-1) + + crops_res = torch.zeros_like(bounds_res) + crops_res[:,:,1] = config['crop_size'] + + start_coords = torch.zeros_like(bounds_res) + start_coords[:,:,1] = config['crop_size'] + + return cropped_imgs, cropped_masks, bounds_res, crops_res, start_coords + +def rescale_cropped(imgs: torch.Tensor, masks: torch.Tensor, bounds: torch.Tensor, + crops: torch.Tensor, start_coords: torch.Tensor): + if config['crop_size'] != config['target_size']: + scale = float(config['target_size'])/float(config['crop_size']) + scale_tensor = torch.tensor([scale], device=imgs.device).expand(imgs.shape[0]).unsqueeze(-1) + scaled_cropped = t.scale(imgs, scale_tensor) + scaled_cropped_masks = t.scale(masks.unsqueeze(1), scale_tensor).squeeze(1) + start = int((config['crop_size'] != config['target_size'])/2) + end = start + config['target_size'] + s_c = scaled_cropped[:,:,start:(end),start:(end)] + s_m = scaled_cropped_masks[:,start:(end),start:(end)] + + rmins, rmaxs, cmins, cmaxs = get_bounds_t(crops) + centers = get_centers_helper(rmins, rmaxs, cmins, cmaxs) + + def rescale_coords(unscaled_coords): + coords = zoom_coords(unscaled_coords, scale_tensor, centers) + coords[:,0] = coords[:,0] - start + coords[:,1] = coords[:,1] - start + return coords + + r_bounds = rescale_coords(bounds) + r_crops = rescale_coords(crops) + r_start_coords = rescale_coords(start_coords) + + assert torch.allclose(r_crops[:,:,0], torch.zeros_like(r_crops)) + assert torch.allclose(r_crops[:,:,1], torch.zeros_like(r_crops) + config['target_size']) + + + return s_c, s_m, r_bounds, r_crops, r_start_coords + else: + return imgs, masks, bounds, crops, start_coords \ No newline at end of file diff --git a/shifthappens/tasks/lost_in_translation/affine_transformations/affine_linspace.py b/shifthappens/tasks/lost_in_translation/affine_transformations/affine_linspace.py new file mode 100644 index 00000000..79cf2cd5 --- /dev/null +++ b/shifthappens/tasks/lost_in_translation/affine_transformations/affine_linspace.py @@ -0,0 +1,236 @@ +import tqdm +import torch +import shifthappens.tasks.lost_in_translation.affine_transformations.affine as a +import numpy as np +import math +import gc +import random + +def eval_batched_numpy(data, model, eval_device, batch_size = 1000): + if eval_device == data.device and data.shape[0] <= batch_size: + total_num = data.shape[0] + results = [] + with torch.no_grad(): + res = 
model(data).detach().cpu().numpy() + return res + else: + total_num = data.shape[0] + results = [] + dataset = torch.utils.data.TensorDataset(data) + loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=True) + with torch.no_grad(): + for i, [data_slice] in enumerate(tqdm.tqdm(loader, leave=False, desc="eval model")): + data_slice = data_slice.to(eval_device) + res = model(data_slice).detach().cpu().numpy() + # if (i // 10) == 0: + # gc.collect() + results.append(res) + return np.concatenate(results, axis=0) + +def np_softmax(x): + max = np.max(x,axis=1,keepdims=True) + e_x = np.exp(x - max) + sum = np.sum(e_x,axis=1,keepdims=True) + return e_x / sum + +def calculate_zoom_for_target(target, bounds, crops, start_coords): + zooms_bounds, _ = a.get_zoom_bounds(crops, start_coords) + zooms_current = a.get_zooms(bounds, crops) + zooms_required = target / zooms_current + zoom = torch.maximum(zooms_required, zooms_bounds).clamp(min=0.0, max=1.0) + return zoom + +def gather_masks_statistics(mask, bounds, crops): + center = a.get_centers(bounds) + center_m = a.get_centers(crops) + center[:,0] = center[:,0] - center_m[:,0] + center[:,1] = center[:,1] - center_m[:,1] + + zooms = a.get_zooms(bounds, crops) + + return {'center': center.numpy(), 'zoom': zooms.numpy()} + + # super slow currently + # occupancy = a.calc_mask_occupancy(mask) + + # return {'center': center.numpy(), 'zoom': zooms.numpy(), 'occupancy': occupancy} + +def rotation_linspace(model, model_name, data, eval_device, batch_size_model, batch_size_rotation, resolution, idx_fun=lambda x:x, do_resize=True, save_dir=None): + + if save_dir is not None: + subdir = str(random.randint(1, 99999)) + from pathlib import Path + p = Path(save_dir) / subdir + p.mkdir(parents=False, exist_ok=True) + exp_dir = str(p) + else: + exp_dir = None + + def do_loop(loop_i, model, model_name, eval_device, batch_size_model, resolution, idx_fun, do_resize, results_rotation, d): + + datapoint_zoom, cat, elem = d + imgs, masks, bounds, crops, start_coords = tuple(map(lambda x: x.to(eval_device), datapoint_zoom)) + + target_zoom = 1./math.sqrt(2) + + zoom = calculate_zoom_for_target(target_zoom, bounds, crops, start_coords) + + max_zoom = 0.0 + zoomed = a.do_zoom(imgs, masks, bounds, crops, start_coords, zoom, max_zoom, verbose=False) + imgs, masks, bounds, crops, start_coords = zoomed + + max_angle = 45./180.*np.pi + angles = torch.minimum(a.find_max_angle(bounds, crops, start_coords, resize=do_resize), torch.tensor([max_angle], device=eval_device)) + steps = torch.linspace(0,resolution,resolution, dtype=torch.float32, device=eval_device)/resolution + angles_steps = -angles + (2*angles * steps) + + num_rotation_l = math.ceil(angles_steps.shape[0] / batch_size_rotation) + iterator = range(num_rotation_l) + if num_rotation_l > 10: + iterator = tqdm.tqdm(iterator, leave=False, desc='gen_data') + res_rotation = [] + res_rotation_m = [] + res_rotation_b = [] + res_rotation_c = [] + for i in iterator: + angles_slice = angles_steps[(batch_size_rotation*i):(batch_size_rotation*i+batch_size_rotation)] + datas = a.expand_data(imgs, masks, bounds, crops, start_coords, angles_slice.shape[0]) + r_imgs, r_masks, r_bounds, r_crops, r_start_coords = a.rescale_cropped(*a.crop_batches(*a.rotate(*datas, angles_slice, max_angle, verbose=False))) + res_rotation.append(r_imgs.cpu()) + res_rotation_m.append(r_masks.cpu()) + res_rotation_b.append(r_bounds.cpu()) + res_rotation_c.append(r_crops.cpu()) + del datas + + r_imgs = torch.cat(res_rotation, 0) + r_masks 
= torch.cat(res_rotation_m, 0) + r_bounds = torch.cat(res_rotation_b, 0) + r_crops = torch.cat(res_rotation_c, 0) + #rotated = a.rotate(*datas, angles_steps, max_angle, verbose=False) + + #cropped_imgs, cropped_masks = a.crop_batches(*rotated) + #rescaled_imgs_rotated, rescaled_masks = a.rescale_cropped(cropped_imgs, cropped_masks) + + + if eval_device.type == 'cuda': + with torch.no_grad(): + pred_rotated = eval_batched_numpy(r_imgs, model, eval_device, batch_size=batch_size_model) + softmaxed_rotated = np_softmax(pred_rotated) + else: + #a.plot_debug_random_pytorch(*res) + pred_rotated = eval_batched_numpy(r_imgs, model, eval_device) + softmaxed_rotated = np_softmax(pred_rotated) + + angles_cpu = angles_steps.cpu() + + if save_dir is not None: + sample_p = save_sample_images(r_imgs, softmaxed_rotated, angles_cpu, cat, exp_dir, loop_i) + else: + sample_p = None + + res = { + "model": model_name, + "data": idx_fun(elem), + "cat": cat, + "params": angles_cpu.numpy(), + "results" : softmaxed_rotated, + "masks_stats": gather_masks_statistics(r_masks, r_bounds, r_crops), + "sample_params": sample_p, + "sample_idx": loop_i, + "exp_dir":exp_dir + } + + results_rotation.append(res) + + results_rotation = [] + + for i,d in enumerate(tqdm.tqdm(data, desc=f"{model_name}:rotation")): + do_loop(i,model, model_name, eval_device, batch_size_model, resolution, idx_fun, do_resize, results_rotation, d) + gc.collect() + + return results_rotation + +def save_sample_images(images, out, params, label: int, dir, idx, model_check=None): + if type(images) is list: + def get_p(i,j,idx): + if isinstance(params, torch.Tensor): + return params[i][j][idx] + else: + return tuple(p[i][j][idx] for p in params) + + miss_classes = [] + for i in range(len(out)): + for j in range(len(out[i])): + classification: np.ndarray = out[i][j].argmax(1) + miss_class = np.where(classification != label)[0] + if miss_class.shape[0] != 0: + w_sample_idx_idx = random.choice(range(miss_class.shape[0])) + w_sample_idx = miss_class[w_sample_idx_idx] + miss_classes.append((i,j,w_sample_idx)) + if len(miss_classes) == 0: + w_sample_idx = None + wrong_params = None + else: + w_sample_idx = random.choice(miss_classes) + (i,j,w_arr_idx) = w_sample_idx + wrong_sample = images[i][j][w_arr_idx].cpu().permute(1,2,0).numpy() + wrong_params = get_p(i,j,w_arr_idx) + assert label != out[i][j].argmax(1)[w_arr_idx] + with open(f'{dir}/{idx}_wrong.npy', 'wb') as f_w: + np.save(f_w, wrong_sample) + if model_check is not None: + test = np.load(f'{dir}/{idx}_wrong.npy') + m_d = next(model_check.parameters()).device + test_torch = torch.from_numpy(test).permute(-1,0,1).unsqueeze(0).to(m_d) + res = model_check(test_torch) + if label == res.argmax(): + import debugpy + + # 5678 is the default attach port in the VS Code debug configurations. 
Unless a host and port are specified, host defaults to 127.0.0.1 + debugpy.listen(5678) + print("Waiting for debugger attach") + debugpy.wait_for_client() + debugpy.breakpoint() + print('break on this line') + assert label != res.argmax() + s_i = random.choice(range(len(images))) + s_j = random.choice(range(len(images[s_i]))) + sample_arr_idx = random.choice(range(images[s_i][s_j].shape[0])) + sample_idx = (s_i, s_j, sample_arr_idx) + sample = images[s_i][s_j][sample_arr_idx].cpu().permute(1,2,0).numpy() + sample_params = get_p(s_i, s_j, sample_arr_idx) + with open(f'{dir}/{idx}.npy', 'wb') as f_s: + np.save(f_s, sample) + else: + def get_p(idx): + if isinstance(params, torch.Tensor): + return params[idx].cpu().numpy() + else: + return tuple(p[idx].cpu().numpy() for p in params) + classification = out.argmax(1) + miss_class = np.where(classification != label)[0] + if miss_class.shape[0] == 0: + w_sample_idx = None + wrong_params = None + else: + w_sample_idx_idx = random.choice(range(miss_class.shape[0])) + w_sample_idx = miss_class[w_sample_idx_idx] + wrong_sample = images[w_sample_idx].cpu().permute(1,2,0).numpy() + wrong_params = get_p(w_sample_idx) + assert label != classification[w_sample_idx] + with open(f'{dir}/{idx}_wrong.npy', 'wb') as f_w: + np.save(f_w, wrong_sample) + if model_check is not None: + test = np.load(f'{dir}/{idx}_wrong.npy') + m_d = next(model_check.parameters()).device + test_torch = torch.from_numpy(test).permute(-1,0,1).unsqueeze(0).to(m_d) + res = model_check(test_torch) + assert label != res.argmax() + + sample_idx = random.choice(range(images.shape[0])) + sample = images[sample_idx].cpu().permute(1,2,0).numpy() + sample_params = get_p(sample_idx) + with open(f'{dir}/{idx}.npy', 'wb') as f_s: + np.save(f_s, sample) + + return (sample_idx, sample_params), (w_sample_idx, wrong_params) \ No newline at end of file diff --git a/shifthappens/tasks/lost_in_translation/affine_transformations/affine_linspace_adaptive.py b/shifthappens/tasks/lost_in_translation/affine_transformations/affine_linspace_adaptive.py new file mode 100644 index 00000000..71c0aa30 --- /dev/null +++ b/shifthappens/tasks/lost_in_translation/affine_transformations/affine_linspace_adaptive.py @@ -0,0 +1,418 @@ +import tqdm +import torch +import shifthappens.tasks.lost_in_translation.affine_transformations.affine as a +import shifthappens.tasks.lost_in_translation.affine_transformations.affine_linspace as a_s +import numpy as np +import math +import gc +import random +import tqdm +import gc +import collections +from skimage.feature.peak import peak_local_max + +def find_mins(res_correct, res_incorrect, past_size_x, past_size_y, nums, find_min_correct=True): + if find_min_correct: + mat_t: np.ndarray = res_correct.reshape(past_size_x, past_size_y)#.cpu().numpy() + mat = (-1)*mat_t + np.max(mat_t) + indices_unraveled = peak_local_max(mat, min_distance=3, exclude_border=False, num_peaks=nums) + else: + mat_correct: np.ndarray = res_correct.reshape(past_size_x, past_size_y) + mat_res_incorrect: np.ndarray = res_incorrect.reshape(past_size_x, past_size_y) + temp = mat_correct - mat_res_incorrect + mat = (-1)*temp + np.max(temp) + indices_unraveled = peak_local_max(mat, min_distance=3, exclude_border=False, num_peaks=nums) + indices = np.ravel_multi_index((indices_unraveled[:,0], indices_unraveled[:,1]), mat.shape) + return indices[:nums], (-1)*np.take(np.reshape(mat, -1), indices) + +def sample_adaptive(trans_xs: list[torch.Tensor], trans_ys: list[torch.Tensor], softmaxes_s: list[np.ndarray], + 
num_points, resolution, cat, bounds, size_recursive, eval_device, batch_size_translation, data, model, + batch_size_model, past_t_x, past_t_y, find_min_correct=True, adapt_resolution=False): + results = {} + for idx in range(len(trans_xs)): + softmaxes = softmaxes_s[idx] + res_incorrect = np.max(np.delete(softmaxes, cat, axis=1), axis=1) + res_correct = softmaxes[:, cat] + past_size_x = past_t_x[idx].shape[0] + past_size_y = past_t_y[idx].shape[0] + idx_mins, min_vals = find_mins(res_correct, res_incorrect, past_size_x, past_size_y, num_points, find_min_correct=find_min_correct) + for idx2 in range(idx_mins.shape[0]): + results[min_vals[idx2]] = (idx, idx_mins[idx2]) + + if len(results) == 0: + for idx in range(len(trans_xs)): + softmaxes = softmaxes_s[idx] + res_correct = softmaxes[:, cat] + + idx_min = np.argmin(res_correct) + results[res_correct[idx_min].item()] = (idx, idx_min) + + keys_sorted = list(sorted(results.keys())) + + move_max_below, move_max_above, move_max_left, move_max_right = bounds + height = move_max_above + move_max_below + if size_recursive < 1.0: + area_height = size_recursive * height + else: + area_height = min(height.item(), size_recursive) + width = move_max_right + move_max_left + if size_recursive < 1.0: + area_width = size_recursive * width + else: + area_width = min(width.item(), size_recursive) + center = a.get_centers(data[2]) + results_imgs = [] + results_masks = [] + results_bounds = [] + results_crops = [] + results_pred_translated = [] + results_softmaxed_translated = [] + results_coords_x = [] + results_coords_y = [] + results_points = [] + results_t_x = [] + results_t_y = [] + + for key in keys_sorted[:num_points]: + (list_idx, min_idx) = results[key] + x_coord = trans_xs[list_idx][min_idx] + y_coord = trans_ys[list_idx][min_idx] + results_points.append((x_coord.item(), y_coord.item(), list_idx)) + x_start = max(x_coord - (area_height / 2), (-1)*move_max_above) + x_end = min(x_coord + (area_height / 2), move_max_below) + y_start = max(y_coord - (area_width / 2), (-1)*move_max_left) + y_end = min(y_coord + (area_width / 2), move_max_right) + t_x, t_y = calculate_linspace(x_start, x_end, y_start, y_end, resolution, center, eval_device, adapt_resolution=adapt_resolution) + + combinations = torch.cartesian_prod(t_x, t_y) + comb_x = combinations[:,0] + comb_y = combinations[:,1] + assert math.ceil(comb_x.shape[0] / min(batch_size_translation, batch_size_model)) == 1 + + temp = translation_helper(model, comb_x, comb_y, batch_size_translation, batch_size_model, data, eval_device) + r_imgs, r_masks, r_bounds, r_crops, pred_translated, softmaxed_translated = temp + results_imgs.append(r_imgs) + results_masks.append(r_masks) + results_bounds.append(r_bounds) + results_crops.append(r_crops) + results_pred_translated.append(pred_translated) + results_softmaxed_translated.append(softmaxed_translated) + results_coords_x.append(comb_x) + results_coords_y.append(comb_y) + results_t_x.append(t_x) + results_t_y.append(t_y) + return results_imgs, results_masks, results_bounds, results_crops, results_pred_translated, results_softmaxed_translated, results_coords_x, results_coords_y, results_points, results_t_x, results_t_y + + +def translation_helper(model, comb_x, comb_y, batch_size_translation, batch_size_model, data, eval_device, only_softmaxes=False): + num_trans_l = math.ceil(comb_x.shape[0] / batch_size_translation) + iterator = range(num_trans_l) + res_translated = [] + res_translated_m = [] + res_translated_b = [] + res_translated_c = [] + + overall_steps = 
math.ceil(comb_x.shape[0] / min(batch_size_translation, batch_size_model)) + + if overall_steps > 1: + iterator = tqdm.tqdm(iterator, leave=False, desc='gen_data') + + assert not only_softmaxes + + def do_inner_loop(i): + comb_x_slice = comb_x[(batch_size_translation*i):(batch_size_translation*i+batch_size_translation)] + comb_y_slice = comb_y[(batch_size_translation*i):(batch_size_translation*i+batch_size_translation)] + datas = a.expand_data(*data, comb_x_slice.shape[0]) + r_imgs, r_masks, r_bounds, r_crops, r_start_coords = a.rescale_cropped(*a.crop_batches(*a.translate_xy(*datas, comb_x_slice, comb_y_slice, verbose=False))) + res_translated.append(r_imgs.cpu()) + res_translated_m.append(r_masks.cpu()) + res_translated_b.append(r_bounds.cpu()) + res_translated_c.append(r_crops.cpu()) + + + for i in iterator: + do_inner_loop(i) + + r_imgs = torch.cat(res_translated, 0) + r_masks = torch.cat(res_translated_m, 0) + r_bounds = torch.cat(res_translated_b, 0) + r_crops = torch.cat(res_translated_c, 0) + + if eval_device.type == 'cuda': + with torch.no_grad(): + pred_translated = a_s.eval_batched_numpy(r_imgs, model, eval_device, batch_size=batch_size_model) + softmaxed_translated = a_s.np_softmax(pred_translated) + else: + pred_translated = a_s.eval_batched_numpy(r_imgs, model, eval_device) + softmaxed_translated = a_s.np_softmax(pred_translated) + + else: + with torch.no_grad(): + datas = a.expand_data(*data, comb_x.shape[0]) + def guard(): + r_imgs, r_masks, r_bounds, r_crops, r_start_coords = a.rescale_cropped(*a.crop_batches(*a.translate_xy(*datas, comb_x, comb_y, verbose=False))) + if only_softmaxes: + return r_imgs, None, None, None + else: + return r_imgs, r_masks.cpu(), r_bounds.cpu(), r_crops.cpu() + r_imgs_cuda, r_masks, r_bounds, r_crops = guard() + if eval_device.type == 'cuda': + pred_translated = a_s.eval_batched_numpy(r_imgs_cuda, model, eval_device, batch_size=batch_size_model) + softmaxed_translated = a_s.np_softmax(pred_translated) + else: + pred_translated = a_s.eval_batched_numpy(r_imgs_cuda, model, eval_device) + softmaxed_translated = a_s.np_softmax(pred_translated) + if only_softmaxes: + r_imgs = None + else: + r_imgs = r_imgs_cuda.cpu() + + + return r_imgs, r_masks, r_bounds, r_crops, pred_translated, softmaxed_translated + +def calculate_linspace(start_x, end_x, start_y, end_y, resolution, center, eval_device, adapt_resolution=False): + if resolution > 10: + steps = torch.linspace(0,resolution,resolution, dtype=torch.float32, device=eval_device)/resolution + + t_x = start_x + (end_x - start_x) * steps + t_y = start_y + (end_y - start_y) * steps + else: + center_x = center[:,0] + center_y = center[:,1] + + def calculate_steps(start, end, center, step_size): + offset = (center % step_size) + #calculation + #start2 = res*((start - offset) // res) + res*math.ceil(((start - offset) % res)) + offset + real_start = start - offset + start_i = torch.div(real_start, step_size, rounding_mode='trunc') + start_i = start_i.int() + real_end = end - offset + end_i = torch.div(real_end, step_size, rounding_mode='floor') + end_i = end_i.int() + start_i = min(start_i, end_i) + leng = end_i - start_i + assert leng >= 0 + if (leng == 0).item(): + if adapt_resolution: + if step_size <= 0.05: + return torch.zeros_like(start_i).float() + start + else: + return calculate_steps(start, end, center, 1/2*step_size) + else: + #next step is too much, maybe last step would be ok (we may skip the first) + last_step = (center % step_size) - step_size + if (last_step >= start and last_step <= 
end).item(): + return last_step.float() + else: + return torch.zeros_like(start_i).float() + start + else: + steps_i =torch.arange(start=start_i.item(),end=end_i.item()+1e-4,step=1.0, dtype=torch.float, device=eval_device) + steps = (step_size * steps_i.float()) + offset + return steps + t_x = calculate_steps(start_x, end_x, center_x, resolution) + t_y = calculate_steps(start_y, end_y, center_y, resolution) + + return t_x,t_y + +def calc_min_grid(start_x, end_x, start_y, end_y, step_size, lenght, eval_device, model, batch_size_translation, batch_size_model, data, cat, big_step_size): + def calc_start_end(bound_lower,bound_upper): + middle = bound_lower + (bound_upper - bound_lower)/2 + start = max(middle - float(lenght)/2,bound_lower) + end = min(middle + float(lenght)/2,bound_upper) + steps = torch.arange(start=start.item(), end=(end.item() + 1e-4), step=step_size, device=eval_device) + return steps + t_x = calc_start_end(start_x, end_x) + t_x_l = t_x.shape[0] + t_y = calc_start_end(start_y, end_y) + t_y_l = t_y.shape[0] + combinations = torch.cartesian_prod(t_x, t_y) + comb_x = combinations[:,0] + comb_y = combinations[:,1] + out = translation_helper(model, comb_x, comb_y, batch_size_translation, batch_size_model, data, eval_device, only_softmaxes=True) + r_imgs, r_masks, r_bounds, r_crops, pred_translated, softmaxed_translated = out + res_correct = softmaxed_translated[:, cat] + from sklearn.linear_model import LinearRegression + regressor = LinearRegression() + params = (comb_x.cpu().numpy(), comb_y.cpu().numpy()) + X = np.stack(params, axis=1) + regressor.fit(X, res_correct) + pred_raw = regressor.predict(X) + temp = (res_correct - pred_raw).reshape(t_x_l, t_y_l) + indices_unraveled = peak_local_max((-1)*(temp + temp.max()), exclude_border=int(0.4/step_size), min_distance=int(0.4/step_size), num_peaks=4) + if indices_unraveled.shape[0] == 0: + idx_min = np.argmin(res_correct - pred_raw) + else: + idx_min = np.ravel_multi_index((indices_unraveled[:,0], indices_unraveled[:,1]), temp.shape) + idx_min = torch.from_numpy(idx_min).to(eval_device) + x_coords_min = comb_x[idx_min] + y_coords_min = comb_y[idx_min] + x_coord_min = (x_coords_min % big_step_size).mean() + y_coord_min = (y_coords_min % big_step_size).mean() + center = torch.stack((x_coord_min, y_coord_min)).view(1,2) + min_grid_data = { + "params": params, + "results": softmaxed_translated, + "t_x": t_x.cpu().numpy(), + "t_y": t_y.cpu().numpy(), + "point": center.cpu().numpy(), + "points":(x_coords_min.cpu().numpy(), y_coords_min.cpu().numpy()) + } + return center, min_grid_data + + +def translation_linspace_adaptive(model, model_name, data, eval_device, batch_size_model, batch_size_translation, + resolutions, num_points, size_recursive, num_recursion, idx_fun=lambda x:x, target_zoom=0.8, save_dir = None, + find_min_correct=True, early_stopping=False, adapt_grid_to_min=False, step_size_center=0.25, leng_center=4, + period_assumption=1.0, constant_offset=None, adapt_resolution=False, check_saving=False): + + if save_dir is not None: + subdir = str(random.randint(1, 99999)) + from pathlib import Path + p = Path(save_dir) / subdir + p.mkdir(parents=False, exist_ok=True) + exp_dir = str(p) + else: + exp_dir = None + + def do_loop(i, model, model_name, d, results_translation): + datapoint_zoom, cat, idx = d + imgs, masks, bounds, crops, start_coords = tuple(map(lambda x: x.to(eval_device), datapoint_zoom)) + + zoom = a_s.calculate_zoom_for_target(target_zoom, bounds, crops, start_coords) + + max_zoom = 0.0 + zoomed = a.do_zoom(imgs, 
masks, bounds, crops, start_coords, zoom, max_zoom, verbose=False) + imgs, masks, bounds, crops, start_coords = zoomed + + t_bounds = a.calculate_translate_bounds(crops, bounds, start_coords) + move_max_below, move_max_above, move_max_left, move_max_right = t_bounds + + resolution = resolutions[0] + + if not adapt_grid_to_min: + align_grid = a.get_centers(bounds) + min_grid_data = None + else: + if constant_offset is None: + align_grid, min_grid_data = calc_min_grid(-move_max_above, move_max_below, -move_max_left, move_max_right, + step_size_center, leng_center, eval_device, model, batch_size_translation, batch_size_model, zoomed, cat, period_assumption) + else: + align_grid = torch.from_numpy(constant_offset).view(1,2).float().to(imgs.device) + min_grid_data = None + t_x, t_y = calculate_linspace(-move_max_above, move_max_below, -move_max_left, move_max_right, + resolution, align_grid, eval_device, adapt_resolution=adapt_resolution) + + if len(t_x) == 0 and len(t_y) != 0: + t_x = torch.tensor([0], device=eval_device, dtype=t_y.dtype) + elif len(t_y) == 0 and len(t_x) != 0: + t_y = torch.tensor([0], device=eval_device, dtype=t_x.dtype) + + + if len(t_x) == 0 and len(t_y) == 0: + fst = move_max_above.item() == 0 and move_max_below.item() == 0 + snd = move_max_below.item() == 0 and move_max_left.item() == 0 + assert (fst or snd) and resolution <= 10.0 + print("skipping") + return {"error": "unable to discretize"} + + combinations = torch.cartesian_prod(t_x, t_y) + comb_x = combinations[:,0] + comb_y = combinations[:,1] + + out = translation_helper(model, comb_x, comb_y, batch_size_translation, batch_size_model, zoomed, eval_device) + r_imgs, r_masks, r_bounds, r_crops, pred_translated, softmaxed_translated = out + + params_cpu = (comb_x.cpu(), comb_y.cpu()) + + params_list_x=[[comb_x.cpu().numpy()]] + params_list_y=[[comb_y.cpu().numpy()]] + results_list=[[softmaxed_translated]] + images_list = [[r_imgs]] + masks_list = [[r_masks]] + points_list = [[]] + t_x_list = [[t_x.cpu().numpy()]] + t_y_list = [[t_y.cpu().numpy()]] + + current_results = ([comb_x], [comb_y], [softmaxed_translated]) + if type(size_recursive) is list: + current_size = size_recursive[0] + else: + current_size = 1.0 + + zoomed_data = (imgs, masks, bounds, crops, start_coords) + + to_cpu_np = lambda l: list(map(lambda x: x.cpu().numpy(), l)) + + def should_stop(results_softmaxs): + if not early_stopping: + return False + else: + for softm in results_softmaxs: + res_correct = softm[:, cat] + res_incorrect = np.max(np.delete(softm, cat, axis=1), axis=1) + if np.any(res_incorrect > res_correct): + return True + return False + + past_t_x = [t_x.cpu().numpy()] + past_t_y = [t_y.cpu().numpy()] + if not should_stop([softmaxed_translated]): + for num in range(num_recursion - 1): + if type(size_recursive) is list: + current_size = size_recursive[num] + else: + current_size = current_size * size_recursive + r_i_l, r_m_l, r_b_l, r_c_l, r_p_l, r_s_l, r_x, r_y, r_p, t_x_l, t_y_l = sample_adaptive(*current_results, + num_points, resolutions[num+1], cat, t_bounds, current_size, eval_device, batch_size_translation, zoomed_data, model, + batch_size_model, past_t_x, past_t_y, find_min_correct=find_min_correct, adapt_resolution=adapt_resolution) + current_results = (r_x, r_y, r_s_l) + params_list_x.append(to_cpu_np(r_x)) + params_list_y.append(to_cpu_np(r_y)) + results_list.append(r_s_l) + images_list.append(r_i_l) + masks_list.append(r_m_l) + points_list.append(r_p) + t_x_list.append(to_cpu_np(t_x_l)) + t_y_list.append(to_cpu_np(t_y_l)) + 
past_t_x = t_x_l + past_t_y = t_y_l + if should_stop(r_s_l): + break + + params_tuple = (params_list_x, params_list_y) + + if save_dir is not None: + model_to_check = None + if check_saving: + model_to_check = model + sample_p = a_s.save_sample_images(images_list, results_list, params_tuple, cat, exp_dir, i, model_check=model_to_check) + else: + sample_p = None + + res = { + "model": model_name, + "data": idx_fun(idx), + "cat": cat, + "params": params_tuple, + "points": points_list, + "results" : results_list, + "masks_stats": a_s.gather_masks_statistics(r_masks, r_bounds, r_crops), + "sample_params": sample_p, + "sample_idx": i, + "t_x": t_x_list, + "t_y": t_y_list, + "min_grid_data": min_grid_data, + "exp_dir": exp_dir, + "batch_size_translation": batch_size_translation + } + + results_translation.append(res) + + results_translation = [] + + for i, d in enumerate(tqdm.tqdm(data, desc=f"{model_name}:translation")): + do_loop(i, model, model_name, d, results_translation) + gc.collect() + + return results_translation \ No newline at end of file diff --git a/shifthappens/tasks/lost_in_translation/affine_transformations/statistics.py b/shifthappens/tasks/lost_in_translation/affine_transformations/statistics.py new file mode 100644 index 00000000..2eab43c5 --- /dev/null +++ b/shifthappens/tasks/lost_in_translation/affine_transformations/statistics.py @@ -0,0 +1,85 @@ +import numpy as np +import math + +label_map = None + +def max_freedom(d): + t_x = d["t_x"][0][0] + t_x_l = t_x[-1]-t_x[0] + t_y = d["t_y"][0][0] + t_y_l = t_y[-1]-t_y[0] + return max(t_x_l,t_y_l) + +def get_correct_class(d): + i = d['data'] + cat_id = label_map[i] + return cat_id + +def has_wrong_classf(d): + idx = get_correct_class(d) + resu = d['results'] + has_wrong = False + if 'min_grid_data' in d and d['min_grid_data'] is not None: + min_data_res = d['min_grid_data']['results'] + res_correct = min_data_res[:, idx] + res_incorrect = np.max(np.delete(min_data_res, idx, axis=1), axis=1) + has_wrong = has_wrong or np.any(res_incorrect > res_correct) + + if type(resu) is list: + for rec in resu: + for point in rec: + res_correct = point[:, idx] + res_incorrect = np.max(np.delete(point, idx, axis=1), axis=1) + has_wrong = has_wrong or np.any(res_incorrect > res_correct) + else: + res_correct = resu[:, idx] + res_incorrect = np.max(np.delete(resu, idx, axis=1), axis=1) + has_wrong = has_wrong or np.any(res_incorrect > res_correct) + return has_wrong + +def adaptive_worst_case(result): + wrong_counter = 0 + for d in result: + has_wrong = has_wrong_classf(d) + if has_wrong: + wrong_counter += 1 + return (len(result) - wrong_counter)/len(result) + +def is_correct(d, mode): + cat_id = get_correct_class(d) + resu = d['results'] + if type(resu) is list: + assert mode == "trans" + t_x = d['params'][0][0][0] + t_y = d['params'][1][0][0] + min_idx = np.argmin((t_x**2)+(t_y**2)) + cat_pred = np.argmax(d['results'][0][0][min_idx]) + return cat_id == cat_pred + else: + if mode == "trans": + t_x,t_y = d['params'] + min_idx = np.argmin((t_x**2)+(t_y**2)) + cat_pred = np.argmax(d['results'][min_idx]) + elif mode == "rotation": + rot = d['params'] + min_idx = np.argmin(np.abs(rot)) + cat_pred = np.argmax(d['results'][min_idx]) + elif mode == "zoom": + zoom = d['params'] + min_idx = np.argmax(zoom) + cat_pred = np.argmax(d['results'][min_idx]) + return cat_id == cat_pred + +def adaptive_base_case(result, mode): + wrong_counter = 0 + for d in result: + has_wrong = not is_correct(d, mode) + if has_wrong: + wrong_counter += 1 + return (len(result) - 
wrong_counter)/len(result) + +def radiant_to_degree(data): + return data * (180./math.pi) + +def gt_30(d): + return np.max(radiant_to_degree(d['params'])) >= 30 \ No newline at end of file diff --git a/shifthappens/tasks/lost_in_translation/imagenet_s/__init__.py b/shifthappens/tasks/lost_in_translation/imagenet_s/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/shifthappens/tasks/lost_in_translation/imagenet_s/imagenet_s.py b/shifthappens/tasks/lost_in_translation/imagenet_s/imagenet_s.py new file mode 100644 index 00000000..4e7a56c0 --- /dev/null +++ b/shifthappens/tasks/lost_in_translation/imagenet_s/imagenet_s.py @@ -0,0 +1,217 @@ +from torch.utils import data +from torchvision.datasets import ImageFolder, ImageNet +import torch +import os +from PIL import Image +import numpy as np +import json +import json +import ntpath + +#Adapted from the ImagentS repository: +#https://github.com/UnsupervisedSemanticSegmentation/ImageNet-S + +class ImageNetSEvalDataset(ImageFolder): + def __init__(self, imagenet_root, label_root, name_list, label_path, transform = None, + simple_items=False, use_new_labels=False, prefilter_items=False, + transform_mask_to_img_classes=False): + super().__init__(label_root) + + self.simple_items = simple_items + + self.imagenet = ImageNet(imagenet_root, split='val', transform=None) + self.transform = transform + self.label_lst = [] + self.imagenet_idcs = [] + + self.in_filename_to_idx = {} + self.idx_to_filename = {} + for idx, sample in enumerate(self.imagenet.samples): + imagenet_filename = sample[0] + filename_parts = imagenet_filename.split('/') + filename = os.path.splitext(filename_parts[-1])[0] + self.in_filename_to_idx[filename] = idx + self.idx_to_filename[idx] = filename + + with open(name_list, 'r') as f: + names = f.read().splitlines() + for name in names: + imagenet_filename , imagenet_s_filename = name.split(' ') + self.label_lst.append(os.path.join(label_root, imagenet_s_filename)) + + filename_parts = imagenet_filename.split('/') + filename = os.path.splitext(filename_parts[-1])[0] + imagenet_idx = self.in_filename_to_idx[filename] + self.imagenet_idcs.append(imagenet_idx) + + with open(label_path, 'r') as f: + self.new_labels = json.load(f) + self.new_labels_from_file = {f'ILSVRC2012_val_{(i+1):08d}.JPEG': labels for i, labels in enumerate(self.new_labels)} + self.use_new_labels = use_new_labels + + self.classes_50 = "goldfish, tiger shark, goldfinch, tree frog, kuvasz, red fox, siamese cat, american black bear, ladybug, sulphur butterfly, wood rabbit, hamster, wild boar, gibbon, african elephant, giant panda, airliner, ashcan, ballpoint, beach wagon, boathouse, bullet train, cellular telephone, chest, clog, container ship, digital watch, dining table, golf ball, grand piano, iron, lab coat, mixing bowl, motor scooter, padlock, park bench, purse, streetcar, table lamp, television, toilet seat, umbrella, vase, water bottle, water tower, yawl, street sign, lemon, carbonara, agaric" + self.classes_300 = "tench, goldfish, tiger shark, hammerhead, electric ray, ostrich, goldfinch, house finch, indigo bunting, kite, common newt, axolotl, tree frog, tailed frog, mud turtle, banded gecko, american chameleon, whiptail, african chameleon, komodo dragon, american alligator, triceratops, thunder snake, ringneck snake, king snake, rock python, horned viper, harvestman, scorpion, garden spider, tick, african grey, lorikeet, red-breasted merganser, wallaby, koala, jellyfish, sea anemone, conch, fiddler crab, american lobster, spiny 
lobster, isopod, bittern, crane, limpkin, bustard, albatross, toy terrier, afghan hound, bluetick, borzoi, irish wolfhound, whippet, ibizan hound, staffordshire bullterrier, border terrier, yorkshire terrier, lakeland terrier, giant schnauzer, standard schnauzer, scotch terrier, lhasa, english setter, clumber, english springer, welsh springer spaniel, kuvasz, kelpie, doberman, miniature pinscher, malamute, pug, leonberg, great pyrenees, samoyed, brabancon griffon, cardigan, coyote, red fox, kit fox, grey fox, persian cat, siamese cat, cougar, lynx, tiger, american black bear, sloth bear, ladybug, leaf beetle, weevil, bee, cicada, leafhopper, damselfly, ringlet, cabbage butterfly, sulphur butterfly, sea cucumber, wood rabbit, hare, hamster, wild boar, hippopotamus, bighorn, ibex, badger, three-toed sloth, orangutan, gibbon, colobus, spider monkey, squirrel monkey, madagascar cat, indian elephant, african elephant, giant panda, barracouta, eel, coho, academic gown, accordion, airliner, ambulance, analog clock, ashcan, backpack, balloon, ballpoint, barbell, barn, bassoon, bath towel, beach wagon, bicycle-built-for-two, binoculars, boathouse, bonnet, bookcase, bow, brass, breastplate, bullet train, cannon, can opener, carpenter's kit, cassette, cellular telephone, chain saw, chest, china cabinet, clog, combination lock, container ship, corkscrew, crate, crock pot, digital watch, dining table, dishwasher, doormat, dutch oven, electric fan, electric locomotive, envelope, file, folding chair, football helmet, freight car, french horn, fur coat, garbage truck, goblet, golf ball, grand piano, half track, hamper, hard disc, harmonica, harvester, hook, horizontal bar, horse cart, iron, jack-o'-lantern, lab coat, ladle, letter opener, liner, mailbox, megalith, military uniform, milk can, mixing bowl, monastery, mortar, mosquito net, motor scooter, mountain bike, mountain tent, mousetrap, necklace, nipple, ocarina, padlock, palace, parallel bars, park bench, pedestal, pencil sharpener, pickelhaube, pillow, planetarium, plastic bag, polaroid camera, pole, pot, purse, quilt, radiator, radio, radio telescope, rain barrel, reflex camera, refrigerator, rifle, rocking chair, rubber eraser, rule, running shoe, sewing machine, shield, shoji, ski, ski mask, slot, soap dispenser, soccer ball, sock, soup bowl, space heater, spider web, spindle, sports car, steel arch bridge, stethoscope, streetcar, submarine, swimming trunks, syringe, table lamp, tank, teddy, television, throne, tile roof, toilet seat, trench coat, trimaran, typewriter keyboard, umbrella, vase, volleyball, wardrobe, warplane, washer, water bottle, water tower, whiskey jug, wig, wine bottle, wok, wreck, yawl, yurt, street sign, traffic light, consomme, ice cream, bagel, cheeseburger, hotdog, mashed potato, spaghetti squash, bell pepper, cardoon, granny smith, strawberry, lemon, carbonara, burrito, cup, coral reef, yellow lady's slipper, buckeye, agaric, gyromitra, earthstar, bolete" + self.classes_919 = "house finch, stupa, agaric, hen-of-the-woods, wild boar, kit fox, desk, beaker, spindle, lipstick, cardoon, ringneck snake, daisy, sturgeon, scorpion, pelican, bustard, rock crab, rock beauty, minivan, menu, thunder snake, zebra, partridge, lacewing, starfish, italian greyhound, marmot, cardigan, plate, ballpoint, chesapeake bay retriever, pirate, potpie, keeshond, dhole, waffle iron, cab, american egret, colobus, radio telescope, gordon setter, mousetrap, overskirt, hamster, wine bottle, bluetick, macaque, bullfrog, junco, tusker, scuba diver, 
pool table, samoyed, mailbox, purse, monastery, bathtub, window screen, african crocodile, traffic light, tow truck, radio, recreational vehicle, grey whale, crayfish, rottweiler, racer, whistle, pencil box, barometer, cabbage butterfly, sloth bear, rhinoceros beetle, guillotine, rocking chair, sports car, bouvier des flandres, border collie, fiddler crab, slot, go-kart, cocker spaniel, plate rack, common newt, tile roof, marimba, moped, terrapin, oxcart, lionfish, bassinet, rain barrel, american black bear, goose, half track, kite, microphone, shield, mexican hairless, measuring cup, bubble, platypus, saint bernard, police van, vase, lhasa, wardrobe, teapot, hummingbird, revolver, jinrikisha, mailbag, red-breasted merganser, assault rifle, loudspeaker, fig, american lobster, can opener, arctic fox, broccoli, long-horned beetle, television, airship, black stork, marmoset, panpipe, drumstick, knee pad, lotion, french loaf, throne, jeep, jersey, tiger cat, cliff, sealyham terrier, strawberry, minibus, goldfinch, goblet, burrito, harp, tractor, cornet, leopard, fly, fireboat, bolete, barber chair, consomme, tripod, breastplate, pineapple, wok, totem pole, alligator lizard, common iguana, digital clock, bighorn, siamese cat, bobsled, irish setter, zucchini, crock pot, loggerhead, irish wolfhound, nipple, rubber eraser, impala, barbell, snow leopard, siberian husky, necklace, manhole cover, electric fan, hippopotamus, entlebucher, prison, doberman, ruffed grouse, coyote, toaster, puffer, black swan, schipperke, file, prairie chicken, hourglass, greater swiss mountain dog, pajama, ear, pedestal, viaduct, shoji, snowplow, puck, gyromitra, birdhouse, flatworm, pier, coral reef, pot, mortar, polaroid camera, passenger car, barracouta, banded gecko, black-and-tan coonhound, safe, ski, torch, green lizard, volleyball, brambling, solar dish, lawn mower, swing, hyena, staffordshire bullterrier, screw, toilet tissue, velvet, scale, stopwatch, sock, koala, garbage truck, spider monkey, afghan hound, chain, upright, flagpole, tree frog, cuirass, chest, groenendael, christmas stocking, lakeland terrier, perfume, neck brace, lab coat, carbonara, porcupine, shower curtain, slug, pitcher, flat-coated retriever, pekinese, oscilloscope, church, lynx, cowboy hat, table lamp, pug, crate, water buffalo, labrador retriever, weimaraner, giant schnauzer, stove, sea urchin, banjo, tiger, miniskirt, eft, european gallinule, vending machine, miniature schnauzer, maypole, bull mastiff, hoopskirt, coffeepot, four-poster, safety pin, monarch, beer glass, grasshopper, head cabbage, parking meter, bonnet, chiffonier, great dane, spider web, electric locomotive, scotch terrier, australian terrier, honeycomb, leafhopper, beer bottle, mud turtle, lifeboat, cassette, potter's wheel, oystercatcher, space heater, coral fungus, sunglass, quail, triumphal arch, collie, walker hound, bucket, bee, komodo dragon, dugong, gibbon, trailer truck, king crab, cheetah, rifle, stingray, bison, ipod, modem, box turtle, motor scooter, container ship, vestment, dingo, radiator, giant panda, nail, sea slug, indigo bunting, trimaran, jacamar, chimpanzee, comic book, odometer, dishwasher, bolo tie, barn, paddlewheel, appenzeller, great white shark, green snake, jackfruit, llama, whippet, hay, leaf beetle, sombrero, ram, washbasin, cup, wall clock, acorn squash, spotted salamander, boston bull, border terrier, doormat, cicada, kimono, hand blower, ox, meerkat, space shuttle, african hunting dog, violin, artichoke, toucan, bulbul, coucal, red wolf, 
seat belt, bicycle-built-for-two, bow tie, pretzel, bedlington terrier, albatross, punching bag, cocktail shaker, diamondback, corn, ant, mountain bike, walking stick, standard schnauzer, power drill, cardigan, accordion, wire-haired fox terrier, streetcar, beach wagon, ibizan hound, hair spray, car mirror, mountain tent, trench coat, studio couch, pomeranian, dough, corkscrew, broom, parachute, band aid, water tower, teddy, fire engine, hornbill, hotdog, theater curtain, crane, malinois, lion, african elephant, handkerchief, caldron, shopping basket, gown, wolf spider, vizsla, electric ray, freight car, pembroke, feather boa, wallet, agama, hard disc, stretcher, sorrel, trilobite, basset, vulture, tarantula, hermit crab, king snake, robin, bernese mountain dog, ski mask, fountain pen, combination lock, yurt, clumber, park bench, baboon, kuvasz, centipede, tabby, steam locomotive, badger, irish water spaniel, picket fence, gong, canoe, swimming trunks, submarine, echidna, bib, refrigerator, hammer, lemon, admiral, chihuahua, basenji, pinwheel, golfcart, bullet train, crib, muzzle, eggnog, old english sheepdog, tray, tiger beetle, electric guitar, peacock, soup bowl, wallaby, abacus, dalmatian, harvester, aircraft carrier, snowmobile, welsh springer spaniel, affenpinscher, oboe, cassette player, pencil sharpener, japanese spaniel, plunger, black widow, norfolk terrier, reflex camera, ice bear, redbone, mongoose, warthog, arabian camel, bittern, mixing bowl, tailed frog, scabbard, castle, curly-coated retriever, garden spider, folding chair, mouse, prayer rug, red fox, toy terrier, leonberg, lycaenid, poncho, goldfish, red-backed sandpiper, holster, hair slide, coho, komondor, macaw, maltese dog, megalith, sarong, green mamba, sea lion, water ouzel, bulletproof vest, sulphur-crested cockatoo, scottish deerhound, steel arch bridge, catamaran, brittany spaniel, redshank, otter, brabancon griffon, balloon, rule, planetarium, trombone, mitten, abaya, crash helmet, milk can, hartebeest, windsor tie, irish terrier, african chameleon, matchstick, water bottle, cloak, ground beetle, ashcan, crane, gila monster, unicycle, gazelle, wombat, brain coral, projector, custard apple, proboscis monkey, tibetan mastiff, mosque, plastic bag, backpack, drum, norwich terrier, pizza, carton, plane, gorilla, jigsaw puzzle, forklift, isopod, otterhound, vacuum, european fire salamander, apron, langur, boxer, african grey, ice lolly, toilet seat, golf ball, titi, drake, ostrich, magnetic compass, great pyrenees, rhodesian ridgeback, buckeye, dungeness crab, toy poodle, ptarmigan, amphibian, monitor, school bus, schooner, spatula, weevil, speedboat, sundial, borzoi, bassoon, bath towel, pill bottle, acorn, tick, briard, thimble, brass, white wolf, boathouse, yawl, miniature pinscher, barn spider, jean, water snake, dishrag, yorkshire terrier, hammerhead, typewriter keyboard, papillon, ocarina, washer, standard poodle, china cabinet, steel drum, swab, mobile home, german short-haired pointer, saluki, bee eater, rock python, vine snake, kelpie, harmonica, military uniform, reel, thatch, maraca, tricycle, sidewinder, parallel bars, banana, flute, paintbrush, sleeping bag, yellow lady's slipper, three-toed sloth, white stork, notebook, weasel, tiger shark, football helmet, madagascar cat, dowitcher, wreck, king penguin, lighter, timber wolf, racket, digital watch, liner, hen, suspension bridge, pillow, carpenter's kit, butternut squash, sandal, sussex spaniel, hip, american staffordshire terrier, flamingo, analog clock, 
black and gold garden spider, sea cucumber, indian elephant, syringe, lens cap, missile, cougar, diaper, chambered nautilus, garter snake, anemone fish, organ, limousine, horse cart, jaguar, frilled lizard, crutch, sea anemone, guenon, meat loaf, slide rule, saltshaker, pomegranate, acoustic guitar, shopping cart, drilling platform, nematode, chickadee, academic gown, candle, norwegian elkhound, armadillo, horizontal bar, orangutan, obelisk, stone wall, cannon, rugby ball, ping-pong ball, window shade, trolleybus, ice cream, pop bottle, cock, harvestman, leatherback turtle, killer whale, spaghetti squash, chain saw, stinkhorn, espresso maker, loafer, bagel, ballplayer, skunk, chainlink fence, earthstar, whiptail, barrel, kerry blue terrier, triceratops, chow, grey fox, sax, binoculars, ladybug, silky terrier, gas pump, cradle, whiskey jug, french bulldog, eskimo dog, hog, hognose snake, pickup, indian cobra, hand-held computer, printer, pole, bald eagle, american alligator, dumbbell, umbrella, mink, shower cap, tank, quill, fox squirrel, ambulance, lesser panda, frying pan, letter opener, hook, strainer, pick, dragonfly, gar, piggy bank, envelope, stole, ibex, american chameleon, bearskin, microwave, petri dish, wood rabbit, beacon, dung beetle, warplane, ruddy turnstone, knot, fur coat, hamper, beagle, ringlet, mask, persian cat, cellular telephone, american coot, apiary, shovel, coffee mug, sewing machine, spoonbill, padlock, bell pepper, great grey owl, squirrel monkey, sulphur butterfly, scoreboard, bow, malamute, siamang, snail, remote control, sea snake, loupe, model t, english setter, dining table, face powder, tench, jack-o'-lantern, croquet ball, water jug, airedale, airliner, guinea pig, hare, damselfly, thresher, limpkin, buckle, english springer, boa constrictor, french horn, black-footed ferret, shetland sheepdog, capuchin, cheeseburger, miniature poodle, spotlight, wooden spoon, west highland white terrier, wig, running shoe, cowboy boot, brown bear, iron, brassiere, magpie, gondola, grand piano, granny smith, mashed potato, german shepherd, stethoscope, cauliflower, soccer ball, pay-phone, jellyfish, cairn, polecat, trifle, photocopier, shih-tzu, orange, guacamole, hatchet, cello, egyptian cat, basketball, moving van, mortarboard, dial telephone, street sign, oil filter, beaver, spiny lobster, chime, bookcase, chiton, black grouse, jay, axolotl, oxygen mask, cricket, worm fence, indri, cockroach, mushroom, dandie dinmont, tennis ball, howler monkey, rapeseed, tibetan terrier, newfoundland, dutch oven, paddle, joystick, golden retriever, blenheim spaniel, mantis, soft-coated wheaten terrier, little blue heron, convertible, bloodhound, palace, medicine chest, english foxhound, cleaver, sweatshirt, mosquito net, soap dispenser, ladle, screwdriver, fire screen, binder, suit, barrow, clog, cucumber, baseball, lorikeet, conch, quilt, eel, horned viper, night snake, angora, pickelhaube, gasmask, patas" + self.classes_50 = ['background'] + self.classes_50.split(', ') + self.classes_300 = ['background'] + self.classes_300.split(', ') + self.classes_919 = ['background'] + self.classes_919.split(', ') + + class_idx = json.load(open("./data/imagenet/imagenet_class_index.json")) + idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))] + + self.map_seg_label_imagenet_label = {} + self.map_imagenet_label_seg_label = {} + adapted_idx_list = [l.replace('_'," ").lower() for l in idx2label] + if '300' in ntpath.basename(name_list): + self.classes = self.classes_300 + self.mode = 300 
+        elif '50' in ntpath.basename(name_list):
+            self.classes = self.classes_50
+            self.mode = 50
+        elif '919' in ntpath.basename(name_list):
+            self.classes = self.classes_919
+            self.mode = 919
+        else:
+            raise Exception("something wrong...")
+        for i, seg_l in enumerate(self.classes):
+            if seg_l == 'background':
+                self.map_seg_label_imagenet_label[i] = -1
+                self.map_imagenet_label_seg_label[-1] = i
+            else:
+                new_idx = adapted_idx_list.index(seg_l)
+                self.map_seg_label_imagenet_label[i] = new_idx
+                self.map_imagenet_label_seg_label[new_idx] = i
+
+        self.prefilter_items = prefilter_items
+        if self.prefilter_items:
+            self.idx_mapping = self.do_prefilter_items()
+
+        if self.prefilter_items and (not self.simple_items or not self.use_new_labels):
+            raise Exception("not implemented")
+
+        self.transform_mask_to_img_classes = transform_mask_to_img_classes
+
+    def do_prefilter_items(self):
+        # filters: keep only samples whose (re-assessed) labels all have a
+        # corresponding segment in the ImageNet-S mask
+        idx_mapping = {}
+        counter = 0
+        for i in range(len(self.label_lst)):
+            gt = Image.open(self.label_lst[i])
+            gt = np.array(gt)
+            gt_uint = (gt[:, :, 1] * 256 + gt[:, :, 0]).astype(int)
+            # np.float was removed from NumPy; float64 is the equivalent dtype
+            in_segmentation = torch.from_numpy(gt_uint.astype(np.float64))
+            imagenet_idx = self.imagenet_idcs[i]
+            p = self.imagenet.samples[imagenet_idx][0]
+            new_l = self.new_labels_from_file[ntpath.basename(p)]
+            uniques = np.unique(in_segmentation)
+            m_unique = [self.map_seg_label_imagenet_label[x] for x in uniques if x < self.mode]
+            if all(x in m_unique for x in new_l):
+                for x in range(len(new_l)):
+                    idx_mapping[counter] = (i, x)
+                    counter += 1
+        return idx_mapping
+
+    def __getitem__(self, item):
+        if not self.simple_items:
+            img_id = self.imagenet_idcs[item]
+            in_image, in_label = self.imagenet[img_id]
+            if self.use_new_labels:
+                in_label = self.get_new_label(img_id)
+
+            if self.transform is not None:
+                in_image = self.transform(in_image)
+            gt = Image.open(self.label_lst[item])
+            gt = np.array(gt)
+            gt_uint = gt[:, :, 1] * 256 + gt[:, :, 0]
+            gt_uint = torch.from_numpy(gt_uint.astype(np.float64))
+
+            if self.transform_mask_to_img_classes:
+                gt_uint = self.transform_mask(gt_uint)
+
+            if self.transform is not None:
+                gt_transformed = Image.fromarray(np.uint8(gt))
+                gt_transformed = self.transform(gt_transformed)
+            else:
+                gt_transformed = gt
+
+            return in_image, gt_uint, gt_transformed, in_label
+        else:
+            return self.get_raw(item)
+
+    def get_raw(self, item):
+        if self.prefilter_items:
+            is_id, seg_id = self.idx_mapping[item]
+        else:
+            is_id = item
+        img_id = self.imagenet_idcs[is_id]
+        in_image, in_label = self.imagenet[img_id]
+        if self.use_new_labels:
+            in_label = self.get_new_label(img_id)
+        in_image = np.array(in_image)
+        gt = Image.open(self.label_lst[is_id])
+        gt = np.array(gt)
+        gt_uint = (gt[:, :, 1] * 256 + gt[:, :, 0]).astype(int)
+        if self.transform_mask_to_img_classes:
+            gt_uint = self.transform_mask(gt_uint)
+
+        if self.prefilter_items:
+            #assert seg_id in in_image
+            return in_image, gt_uint, in_label[seg_id]
+        else:
+            return in_image, gt_uint, in_label
+
+    def transform_mask(self, mask):
+        uniques = np.unique(mask)
+        new_m = np.zeros_like(mask)
+        for x in uniques:
+            if x >= self.mode:
+                new_m[mask == x] = -1
+                continue
+            new_m[mask == x] = self.map_seg_label_imagenet_label[x]
+        return new_m
+
+    def get_new_label_generic(self, item):
+        if self.prefilter_items:
+            is_id, seg_id = self.idx_mapping[item]
+        else:
+            is_id = item
+        img_id = self.imagenet_idcs[is_id]
+        return self.get_new_label(img_id)
+
+    def get_new_label(self, imagenet_idx):
+        p = self.imagenet.samples[imagenet_idx][0]
+        return self.new_labels_from_file[ntpath.basename(p)]
+
+    def get_imagent_id(self, item):
+        if self.prefilter_items:
+            is_id, seg_id = self.idx_mapping[item]
+        else:
+            is_id = item
+        img_id = self.imagenet_idcs[is_id]
+        return img_id
+
+    def __len__(self):
+        if self.prefilter_items:
+            return len(self.idx_mapping.keys())
+        else:
+            return len(self.label_lst)
+
+def get_param(mode):
+    assert mode in ['50', '300', '919'], 'invalid dataset'
+    params = {
+        '50': {'num_classes': 50,
+               'classes': 'classes_50',
+               'dir': 'ImageNetS50',
+               'names': 'ImageNetS_im50_validation.txt'},
+        '300': {'num_classes': 300,
+                'classes': 'classes_300',
+                'dir': 'ImageNetS300',
+                'names': 'ImageNetS_im300_validation.txt'},
+        '919': {'num_classes': 919,
+                'classes': 'classes_919',
+                'dir': 'ImageNetS919',
+                'names': 'ImageNetS_im919_validation.txt'},
+    }
+
+    return params[mode]
\ No newline at end of file
diff --git a/shifthappens/tasks/lost_in_translation/lost_in_translation.py b/shifthappens/tasks/lost_in_translation/lost_in_translation.py
new file mode 100644
index 00000000..f35e77c9
--- /dev/null
+++ b/shifthappens/tasks/lost_in_translation/lost_in_translation.py
@@ -0,0 +1,208 @@
+"""
+Lost in Translation: evaluates how robust an ImageNet classifier is to
+translations and rotations of the object of interest. ImageNet-S segmentation
+masks are used to locate the object so that it can be shifted and rotated
+inside the crop without leaving the image.
+"""
+import dataclasses
+import os
+
+import numpy as np
+import torchvision.datasets as tv_datasets
+import torchvision.transforms as tv_transforms
+
+import shifthappens.data.base as sh_data
+import shifthappens.data.torch as sh_data_torch
+import shifthappens.utils as sh_utils
+from shifthappens import benchmark as sh_benchmark
+from shifthappens.data import imagenet as sh_imagenet
+from shifthappens.models import base as sh_models
+from shifthappens.models import torchvision as sh_models_t
+from shifthappens.models.base import PredictionTargets
+from shifthappens.tasks.base import Task
+from shifthappens.tasks.metrics import Metric
+from shifthappens.tasks.mixins import OODScoreTaskMixin
+from shifthappens.tasks.task_result import TaskResult
+from shifthappens.tasks.utils import auroc_ood
+from shifthappens.tasks.utils import fpr_at_tpr
+from shifthappens.tasks.base import parameter
+
+import shifthappens.tasks.lost_in_translation.affine_transformations.affine as a
+import shifthappens.tasks.lost_in_translation.affine_transformations.affine_linspace as lin
+import shifthappens.tasks.lost_in_translation.affine_transformations.affine_linspace_adaptive as lin_a
+import shifthappens.tasks.lost_in_translation.affine_transformations.statistics as stat
+
+from torchvision import transforms
+import shifthappens.tasks.lost_in_translation.imagenet_s.imagenet_s as i_s
+import shifthappens.config
+import random
+import math
+
+
+@sh_benchmark.register_task(
+    name="LostInTranslation", relative_data_folder="lost_in_translation", standalone=True
+)
+@dataclasses.dataclass
+class LostInTranslationBase(Task):
+    """
+    Base class for the translation and rotation robustness tasks; loads the
+    ImageNet-S validation segmentations and prepares the per-sample data.
+    """
+
+    default_batch_size: int = 700
+
+    resolution: int = parameter(
+        default=224,
+        description="resolution expected from the model",
+    )
+
+    rotation_cutoff: int = parameter(
+        default=0,
+        description="restrict the dataset to samples with at least this many degrees of rotation",
+    )
+
+    translation_cutoff: int = parameter(
+        default=0,
+        description="restrict the dataset to samples with at least this many pixels of translation freedom",
+    )
+
+
+    resource_s = (
+        "imagenet_s",
+        #"raccoons.tar.gz",
+        #"https://nc.mlcloud.uni-tuebingen.de/index.php/s/JrSQeRgXfw28crC/download/raccoons.tar.gz",
+        None,
+        None,
+        None,
+    )
+
+    def setup(self):
+        """Load and prepare the data."""
+
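+        # setup() expects the ImageNet-S validation data below
+        # <data_root>/imagenet_s (the file lists under names/ and the
+        # validation-segmentation masks under ImageNetS300/). It points the
+        # affine-transformation config at the model's input resolution and
+        # builds the sample-index -> ImageNet-label map consumed by the
+        # statistics module.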
+        folder_name_s, file_name, url, md5 = self.resource_s
+        imagent_s_folder = os.path.join(self.data_root, folder_name_s)
+        # if not os.path.exists(dataset_folder):
+        #     sh_utils.download_and_extract_archive(url, dataset_folder, md5, file_name)
+
+        a.config = a.config_imagenet
+        a.config['target_size'] = self.resolution
+        a.config['crop_size'] = self.resolution
+
+        tt = transforms.ToTensor()
+
+        imagenet_root_path = shifthappens.config.imagenet_validation_path
+
+        params = i_s.get_param('300')
+        num_classes = params['num_classes']
+
+        name_list = os.path.join(f'{imagent_s_folder}/names', params['names'])
+
+        subdir = 'validation-segmentation'
+        gt_dir = os.path.join(imagent_s_folder, params['dir'], subdir)
+        self.dataset = i_s.ImageNetSEvalDataset(imagenet_root_path, gt_dir, name_list, transform=tt,
+                                                use_new_labels=True, simple_items=True, prefilter_items=True,
+                                                transform_mask_to_img_classes=True)
+
+        label_map = {}
+        for i in range(len(self.dataset)):
+            I, gt_uint, in_label = self.dataset[i]
+            label_map[i] = in_label
+
+        stat.label_map = label_map
+
+    def _evaluate(self, model: sh_models.Model) -> TaskResult:
+        data = []
+        idx_list = list(range(len(self.dataset)))
+        random.shuffle(idx_list)
+        for i in idx_list:
+            try:
+                I, gt_uint, in_label = self.dataset[i]
+                loaded = a.load_and_pad_imagenet(I, gt_uint, in_label)
+                element = a.to_torch(*loaded)
+                data.append((element, in_label, i))
+            except AssertionError as e:
+                raise e
+            except:
+                # samples for which no valid square crop around the object
+                # can be constructed are skipped
+                pass
+
+        # getattr avoids Python's name mangling, which would otherwise rewrite
+        # a literal `sh_models_t.__TorchvisionModel` inside this class body
+        if isinstance(model, getattr(sh_models_t, "__TorchvisionModel")):
+            batch_size_model = model.max_batch_size
+            eval_device = model.device
+            m = model.model
+            res = self._eval_model(m, "anon", batch_size_model, data, eval_device)
+            return TaskResult(
+                accuracy_base=res["base_case"],
+                accuracy_worst=res["worst_case"],
+                # a duplicated Metric.Robustness key would silently drop
+                # accuracy_base; group both names under one key (assuming
+                # summary_metrics accepts a tuple of metric names)
+                summary_metrics={
+                    Metric.Robustness: ("accuracy_base", "accuracy_worst"),
+                },
+            )
+        else:
+            raise Exception("TODO!")
+
+
+@sh_benchmark.register_task(
+    name="LostInTranslationTranslation", relative_data_folder="lost_in_translation", standalone=True
+)
+@dataclasses.dataclass
+class LostInTranslationTranslation(LostInTranslationBase):
+    def _eval_model(self, model, model_name, batch_size_model, data, eval_device):
+
+        args = (model, model_name, data, eval_device, batch_size_model)
+
+        #TODO: Make this batch_size independent/configurable
+
+        # perc = 0.5
+        # resolutions = [1.0,0.5,0.2]
+        rec_depth = 2
+        num_points = 2
+        sizes = [4.6, 2.5]
+        # pick the grid spacing so that one linspace of side length `size`
+        # fits into a single model batch (sqrt(batch_size) points per axis)
+        sqrt_bl = math.sqrt(batch_size_model)
+        resolutions = [
+            2.0,
+            (sizes[0] + 1) / (sqrt_bl - 1),
+            (sizes[1] + 1) / (sqrt_bl - 1)]
+
+        leng_center_2 = 2.3
+        step_size_center_2 = (leng_center_2 + 1) / (math.sqrt(batch_size_model) - 1)
+
+        results_translation = lin_a.translation_linspace_adaptive(*args, batch_size_model,
+            resolutions, num_points, size_recursive=sizes, num_recursion=rec_depth, idx_fun=lambda x: x, target_zoom=0.8,
+            save_dir=None, step_size_center=step_size_center_2,
+            leng_center=leng_center_2, adapt_grid_to_min=True, early_stopping=True,
+            find_min_correct=False, adapt_resolution=True)
+
+        stat_translation = {
+            "base_case": stat.adaptive_base_case(results_translation, "trans"),
+            "worst_case": stat.adaptive_worst_case(results_translation, "trans"),
+        }
+
+        return stat_translation
+
+
+@sh_benchmark.register_task(
+    name="LostInTranslationRotation", relative_data_folder="lost_in_translation", standalone=True
+)
+@dataclasses.dataclass
+class LostInTranslationRotation(LostInTranslationBase):
+    def _eval_model(self, model, model_name, batch_size_model, data, eval_device):
+
+        args = (model, model_name, data, eval_device, batch_size_model)
+
+        results_rotation = lin.rotation_linspace(*args, batch_size_rotation=batch_size_model, resolution=200, do_resize=False, save_dir=None)
+
+        stat_rotation = {
+            "base_case": stat.adaptive_base_case(results_rotation, "rotation"),
+            "worst_case": stat.adaptive_worst_case(results_rotation, "rotation"),
+        }
+
+        return stat_rotation
\ No newline at end of file