module @compiled_attn { util.global private @__auto.attn.img_mod.lin.weight = #stream.parameter.named<"model"::"attn.img_mod.lin.weight"> : tensor<18432x3072xf16> util.global private @__auto.attn.img_mod.lin.bias = #stream.parameter.named<"model"::"attn.img_mod.lin.bias"> : tensor<18432xf16> util.global private @__auto.attn.txt_mod.lin.weight = #stream.parameter.named<"model"::"attn.txt_mod.lin.weight"> : tensor<18432x3072xf16> util.global private @__auto.attn.txt_mod.lin.bias = #stream.parameter.named<"model"::"attn.txt_mod.lin.bias"> : tensor<18432xf16> util.global private @__auto.attn.img_attn.qkv.weight = #stream.parameter.named<"model"::"attn.img_attn.qkv.weight"> : tensor<9216x3072xf16> util.global private @__auto.attn.img_attn.qkv.bias = #stream.parameter.named<"model"::"attn.img_attn.qkv.bias"> : tensor<9216xf16> util.global private @__auto.attn.img_attn.norm.query_norm.scale = #stream.parameter.named<"model"::"attn.img_attn.norm.query_norm.scale"> : tensor<128xf16> util.global private @__auto.attn.img_attn.norm.key_norm.scale = #stream.parameter.named<"model"::"attn.img_attn.norm.key_norm.scale"> : tensor<128xf16> util.global private @__auto.attn.txt_attn.qkv.weight = #stream.parameter.named<"model"::"attn.txt_attn.qkv.weight"> : tensor<9216x3072xf16> util.global private @__auto.attn.txt_attn.qkv.bias = #stream.parameter.named<"model"::"attn.txt_attn.qkv.bias"> : tensor<9216xf16> util.global private @__auto.attn.txt_attn.norm.query_norm.scale = #stream.parameter.named<"model"::"attn.txt_attn.norm.query_norm.scale"> : tensor<128xf16> util.global private @__auto.attn.txt_attn.norm.key_norm.scale = #stream.parameter.named<"model"::"attn.txt_attn.norm.key_norm.scale"> : tensor<128xf16> util.global private @__auto.attn.img_attn.proj.weight = #stream.parameter.named<"model"::"attn.img_attn.proj.weight"> : tensor<3072x3072xf16> util.global private @__auto.attn.img_attn.proj.bias = #stream.parameter.named<"model"::"attn.img_attn.proj.bias"> : tensor<3072xf16> 
util.global private @__auto.attn.img_mlp.0.weight = #stream.parameter.named<"model"::"attn.img_mlp.0.weight"> : tensor<12288x3072xf16> util.global private @__auto.attn.img_mlp.0.bias = #stream.parameter.named<"model"::"attn.img_mlp.0.bias"> : tensor<12288xf16> util.global private @__auto.attn.img_mlp.2.weight = #stream.parameter.named<"model"::"attn.img_mlp.2.weight"> : tensor<3072x12288xf16> util.global private @__auto.attn.img_mlp.2.bias = #stream.parameter.named<"model"::"attn.img_mlp.2.bias"> : tensor<3072xf16> util.global private @__auto.attn.txt_attn.proj.weight = #stream.parameter.named<"model"::"attn.txt_attn.proj.weight"> : tensor<3072x3072xf16> util.global private @__auto.attn.txt_attn.proj.bias = #stream.parameter.named<"model"::"attn.txt_attn.proj.bias"> : tensor<3072xf16> util.global private @__auto.attn.txt_mlp.0.weight = #stream.parameter.named<"model"::"attn.txt_mlp.0.weight"> : tensor<12288x3072xf16> util.global private @__auto.attn.txt_mlp.0.bias = #stream.parameter.named<"model"::"attn.txt_mlp.0.bias"> : tensor<12288xf16> util.global private @__auto.attn.txt_mlp.2.weight = #stream.parameter.named<"model"::"attn.txt_mlp.2.weight"> : tensor<3072x12288xf16> util.global private @__auto.attn.txt_mlp.2.bias = #stream.parameter.named<"model"::"attn.txt_mlp.2.bias"> : tensor<3072xf16> func.func @main(%arg0: !torch.vtensor<[1,4096,3072],f16>, %arg1: !torch.vtensor<[1,512,3072],f16>, %arg2: !torch.vtensor<[1,3072],f16>, %arg3: !torch.vtensor<[1,1,4608,64,2,2],f32>) -> (!torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,512,3072],f16>) attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.aten.silu %arg2 : !torch.vtensor<[1,3072],f16> -> !torch.vtensor<[1,3072],f16> %__auto.attn.img_mod.lin.weight = util.global.load @__auto.attn.img_mod.lin.weight : tensor<18432x3072xf16> %1 = torch_c.from_builtin_tensor %__auto.attn.img_mod.lin.weight : tensor<18432x3072xf16> -> !torch.vtensor<[18432,3072],f16> %int0 = torch.constant.int 0 %int1 = 
torch.constant.int 1 %2 = torch.aten.transpose.int %1, %int0, %int1 : !torch.vtensor<[18432,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,18432],f16> %__auto.attn.img_mod.lin.bias = util.global.load @__auto.attn.img_mod.lin.bias : tensor<18432xf16> %3 = torch_c.from_builtin_tensor %__auto.attn.img_mod.lin.bias : tensor<18432xf16> -> !torch.vtensor<[18432],f16> %int6 = torch.constant.int 6 %4 = torch.prims.convert_element_type %3, %int6 : !torch.vtensor<[18432],f16>, !torch.int -> !torch.vtensor<[18432],f32> %int6_0 = torch.constant.int 6 %5 = torch.prims.convert_element_type %0, %int6_0 : !torch.vtensor<[1,3072],f16>, !torch.int -> !torch.vtensor<[1,3072],f32> %int6_1 = torch.constant.int 6 %6 = torch.prims.convert_element_type %2, %int6_1 : !torch.vtensor<[3072,18432],f16>, !torch.int -> !torch.vtensor<[3072,18432],f32> %7 = torch.aten.mm %5, %6 : !torch.vtensor<[1,3072],f32>, !torch.vtensor<[3072,18432],f32> -> !torch.vtensor<[1,18432],f32> %int1_2 = torch.constant.int 1 %8 = torch.aten.mul.Scalar %7, %int1_2 : !torch.vtensor<[1,18432],f32>, !torch.int -> !torch.vtensor<[1,18432],f32> %int1_3 = torch.constant.int 1 %9 = torch.aten.mul.Scalar %4, %int1_3 : !torch.vtensor<[18432],f32>, !torch.int -> !torch.vtensor<[18432],f32> %int1_4 = torch.constant.int 1 %10 = torch.aten.add.Tensor %8, %9, %int1_4 : !torch.vtensor<[1,18432],f32>, !torch.vtensor<[18432],f32>, !torch.int -> !torch.vtensor<[1,18432],f32> %int5 = torch.constant.int 5 %11 = torch.prims.convert_element_type %10, %int5 : !torch.vtensor<[1,18432],f32>, !torch.int -> !torch.vtensor<[1,18432],f16> %int0_5 = torch.constant.int 0 %int0_6 = torch.constant.int 0 %int9223372036854775807 = torch.constant.int 9223372036854775807 %int1_7 = torch.constant.int 1 %12 = torch.aten.slice.Tensor %11, %int0_5, %int0_6, %int9223372036854775807, %int1_7 : !torch.vtensor<[1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,18432],f16> %int1_8 = torch.constant.int 1 %13 = 
torch.aten.unsqueeze %12, %int1_8 : !torch.vtensor<[1,18432],f16>, !torch.int -> !torch.vtensor<[1,1,18432],f16> %int2 = torch.constant.int 2 %int0_9 = torch.constant.int 0 %int9223372036854775807_10 = torch.constant.int 9223372036854775807 %int1_11 = torch.constant.int 1 %14 = torch.aten.slice.Tensor %13, %int2, %int0_9, %int9223372036854775807_10, %int1_11 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,18432],f16> %int-1 = torch.constant.int -1 %int0_12 = torch.constant.int 0 %int3072 = torch.constant.int 3072 %int1_13 = torch.constant.int 1 %15 = torch.aten.slice.Tensor %14, %int-1, %int0_12, %int3072, %int1_13 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_14 = torch.constant.int -1 %int3072_15 = torch.constant.int 3072 %int6144 = torch.constant.int 6144 %int1_16 = torch.constant.int 1 %16 = torch.aten.slice.Tensor %14, %int-1_14, %int3072_15, %int6144, %int1_16 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_17 = torch.constant.int -1 %int6144_18 = torch.constant.int 6144 %int9216 = torch.constant.int 9216 %int1_19 = torch.constant.int 1 %17 = torch.aten.slice.Tensor %14, %int-1_17, %int6144_18, %int9216, %int1_19 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_20 = torch.constant.int -1 %int9216_21 = torch.constant.int 9216 %int12288 = torch.constant.int 12288 %int1_22 = torch.constant.int 1 %18 = torch.aten.slice.Tensor %14, %int-1_20, %int9216_21, %int12288, %int1_22 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_23 = torch.constant.int -1 %int12288_24 = torch.constant.int 12288 %int15360 = torch.constant.int 15360 %int1_25 = torch.constant.int 1 %19 = torch.aten.slice.Tensor %14, %int-1_23, 
%int12288_24, %int15360, %int1_25 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_26 = torch.constant.int -1 %int15360_27 = torch.constant.int 15360 %int18432 = torch.constant.int 18432 %int1_28 = torch.constant.int 1 %20 = torch.aten.slice.Tensor %14, %int-1_26, %int15360_27, %int18432, %int1_28 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %21 = torch.aten.silu %arg2 : !torch.vtensor<[1,3072],f16> -> !torch.vtensor<[1,3072],f16> %__auto.attn.txt_mod.lin.weight = util.global.load @__auto.attn.txt_mod.lin.weight : tensor<18432x3072xf16> %22 = torch_c.from_builtin_tensor %__auto.attn.txt_mod.lin.weight : tensor<18432x3072xf16> -> !torch.vtensor<[18432,3072],f16> %int0_29 = torch.constant.int 0 %int1_30 = torch.constant.int 1 %23 = torch.aten.transpose.int %22, %int0_29, %int1_30 : !torch.vtensor<[18432,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,18432],f16> %__auto.attn.txt_mod.lin.bias = util.global.load @__auto.attn.txt_mod.lin.bias : tensor<18432xf16> %24 = torch_c.from_builtin_tensor %__auto.attn.txt_mod.lin.bias : tensor<18432xf16> -> !torch.vtensor<[18432],f16> %int6_31 = torch.constant.int 6 %25 = torch.prims.convert_element_type %24, %int6_31 : !torch.vtensor<[18432],f16>, !torch.int -> !torch.vtensor<[18432],f32> %int6_32 = torch.constant.int 6 %26 = torch.prims.convert_element_type %21, %int6_32 : !torch.vtensor<[1,3072],f16>, !torch.int -> !torch.vtensor<[1,3072],f32> %int6_33 = torch.constant.int 6 %27 = torch.prims.convert_element_type %23, %int6_33 : !torch.vtensor<[3072,18432],f16>, !torch.int -> !torch.vtensor<[3072,18432],f32> %28 = torch.aten.mm %26, %27 : !torch.vtensor<[1,3072],f32>, !torch.vtensor<[3072,18432],f32> -> !torch.vtensor<[1,18432],f32> %int1_34 = torch.constant.int 1 %29 = torch.aten.mul.Scalar %28, %int1_34 : !torch.vtensor<[1,18432],f32>, !torch.int -> 
!torch.vtensor<[1,18432],f32> %int1_35 = torch.constant.int 1 %30 = torch.aten.mul.Scalar %25, %int1_35 : !torch.vtensor<[18432],f32>, !torch.int -> !torch.vtensor<[18432],f32> %int1_36 = torch.constant.int 1 %31 = torch.aten.add.Tensor %29, %30, %int1_36 : !torch.vtensor<[1,18432],f32>, !torch.vtensor<[18432],f32>, !torch.int -> !torch.vtensor<[1,18432],f32> %int5_37 = torch.constant.int 5 %32 = torch.prims.convert_element_type %31, %int5_37 : !torch.vtensor<[1,18432],f32>, !torch.int -> !torch.vtensor<[1,18432],f16> %int0_38 = torch.constant.int 0 %int0_39 = torch.constant.int 0 %int9223372036854775807_40 = torch.constant.int 9223372036854775807 %int1_41 = torch.constant.int 1 %33 = torch.aten.slice.Tensor %32, %int0_38, %int0_39, %int9223372036854775807_40, %int1_41 : !torch.vtensor<[1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,18432],f16> %int1_42 = torch.constant.int 1 %34 = torch.aten.unsqueeze %33, %int1_42 : !torch.vtensor<[1,18432],f16>, !torch.int -> !torch.vtensor<[1,1,18432],f16> %int2_43 = torch.constant.int 2 %int0_44 = torch.constant.int 0 %int9223372036854775807_45 = torch.constant.int 9223372036854775807 %int1_46 = torch.constant.int 1 %35 = torch.aten.slice.Tensor %34, %int2_43, %int0_44, %int9223372036854775807_45, %int1_46 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,18432],f16> %int-1_47 = torch.constant.int -1 %int0_48 = torch.constant.int 0 %int3072_49 = torch.constant.int 3072 %int1_50 = torch.constant.int 1 %36 = torch.aten.slice.Tensor %35, %int-1_47, %int0_48, %int3072_49, %int1_50 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_51 = torch.constant.int -1 %int3072_52 = torch.constant.int 3072 %int6144_53 = torch.constant.int 6144 %int1_54 = torch.constant.int 1 %37 = torch.aten.slice.Tensor %35, %int-1_51, %int3072_52, %int6144_53, %int1_54 : 
!torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_55 = torch.constant.int -1 %int6144_56 = torch.constant.int 6144 %int9216_57 = torch.constant.int 9216 %int1_58 = torch.constant.int 1 %38 = torch.aten.slice.Tensor %35, %int-1_55, %int6144_56, %int9216_57, %int1_58 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_59 = torch.constant.int -1 %int9216_60 = torch.constant.int 9216 %int12288_61 = torch.constant.int 12288 %int1_62 = torch.constant.int 1 %39 = torch.aten.slice.Tensor %35, %int-1_59, %int9216_60, %int12288_61, %int1_62 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_63 = torch.constant.int -1 %int12288_64 = torch.constant.int 12288 %int15360_65 = torch.constant.int 15360 %int1_66 = torch.constant.int 1 %40 = torch.aten.slice.Tensor %35, %int-1_63, %int12288_64, %int15360_65, %int1_66 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int-1_67 = torch.constant.int -1 %int15360_68 = torch.constant.int 15360 %int18432_69 = torch.constant.int 18432 %int1_70 = torch.constant.int 1 %41 = torch.aten.slice.Tensor %35, %int-1_67, %int15360_68, %int18432_69, %int1_70 : !torch.vtensor<[1,1,18432],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int6_71 = torch.constant.int 6 %42 = torch.prims.convert_element_type %arg0, %int6_71 : !torch.vtensor<[1,4096,3072],f16>, !torch.int -> !torch.vtensor<[1,4096,3072],f32> %int2_72 = torch.constant.int 2 %43 = torch.prim.ListConstruct %int2_72 : (!torch.int) -> !torch.list %int0_73 = torch.constant.int 0 %true = torch.constant.bool true %result0, %result1 = torch.aten.var_mean.correction %42, %43, %int0_73, %true : !torch.vtensor<[1,4096,3072],f32>, !torch.list, !torch.int, !torch.bool -> 
!torch.vtensor<[1,4096,1],f32>, !torch.vtensor<[1,4096,1],f32> %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7 %int1_74 = torch.constant.int 1 %44 = torch.aten.add.Scalar %result0, %float9.999990e-07, %int1_74 : !torch.vtensor<[1,4096,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,4096,1],f32> %45 = torch.aten.rsqrt %44 : !torch.vtensor<[1,4096,1],f32> -> !torch.vtensor<[1,4096,1],f32> %int1_75 = torch.constant.int 1 %46 = torch.aten.sub.Tensor %arg0, %result1, %int1_75 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,4096,1],f32>, !torch.int -> !torch.vtensor<[1,4096,3072],f32> %47 = torch.aten.mul.Tensor %46, %45 : !torch.vtensor<[1,4096,3072],f32>, !torch.vtensor<[1,4096,1],f32> -> !torch.vtensor<[1,4096,3072],f32> %int5_76 = torch.constant.int 5 %48 = torch.prims.convert_element_type %47, %int5_76 : !torch.vtensor<[1,4096,3072],f32>, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %int1_77 = torch.constant.int 1 %int1_78 = torch.constant.int 1 %49 = torch.aten.add.Scalar %16, %int1_77, %int1_78 : !torch.vtensor<[1,1,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %50 = torch.aten.mul.Tensor %49, %48 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,4096,3072],f16> -> !torch.vtensor<[1,4096,3072],f16> %int1_79 = torch.constant.int 1 %51 = torch.aten.add.Tensor %50, %15, %int1_79 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,1,3072],f16>, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %int4096 = torch.constant.int 4096 %int3072_80 = torch.constant.int 3072 %52 = torch.prim.ListConstruct %int4096, %int3072_80 : (!torch.int, !torch.int) -> !torch.list %53 = torch.aten.view %51, %52 : !torch.vtensor<[1,4096,3072],f16>, !torch.list -> !torch.vtensor<[4096,3072],f16> %__auto.attn.img_attn.qkv.weight = util.global.load @__auto.attn.img_attn.qkv.weight : tensor<9216x3072xf16> %54 = torch_c.from_builtin_tensor %__auto.attn.img_attn.qkv.weight : tensor<9216x3072xf16> -> !torch.vtensor<[9216,3072],f16> %int0_81 
= torch.constant.int 0 %int1_82 = torch.constant.int 1 %55 = torch.aten.transpose.int %54, %int0_81, %int1_82 : !torch.vtensor<[9216,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,9216],f16> %__auto.attn.img_attn.qkv.bias = util.global.load @__auto.attn.img_attn.qkv.bias : tensor<9216xf16> %56 = torch_c.from_builtin_tensor %__auto.attn.img_attn.qkv.bias : tensor<9216xf16> -> !torch.vtensor<[9216],f16> %int6_83 = torch.constant.int 6 %57 = torch.prims.convert_element_type %56, %int6_83 : !torch.vtensor<[9216],f16>, !torch.int -> !torch.vtensor<[9216],f32> %int6_84 = torch.constant.int 6 %58 = torch.prims.convert_element_type %53, %int6_84 : !torch.vtensor<[4096,3072],f16>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int6_85 = torch.constant.int 6 %59 = torch.prims.convert_element_type %55, %int6_85 : !torch.vtensor<[3072,9216],f16>, !torch.int -> !torch.vtensor<[3072,9216],f32> %60 = torch.aten.mm %58, %59 : !torch.vtensor<[4096,3072],f32>, !torch.vtensor<[3072,9216],f32> -> !torch.vtensor<[4096,9216],f32> %int1_86 = torch.constant.int 1 %61 = torch.aten.mul.Scalar %60, %int1_86 : !torch.vtensor<[4096,9216],f32>, !torch.int -> !torch.vtensor<[4096,9216],f32> %int1_87 = torch.constant.int 1 %62 = torch.aten.mul.Scalar %57, %int1_87 : !torch.vtensor<[9216],f32>, !torch.int -> !torch.vtensor<[9216],f32> %int1_88 = torch.constant.int 1 %63 = torch.aten.add.Tensor %61, %62, %int1_88 : !torch.vtensor<[4096,9216],f32>, !torch.vtensor<[9216],f32>, !torch.int -> !torch.vtensor<[4096,9216],f32> %int5_89 = torch.constant.int 5 %64 = torch.prims.convert_element_type %63, %int5_89 : !torch.vtensor<[4096,9216],f32>, !torch.int -> !torch.vtensor<[4096,9216],f16> %int1_90 = torch.constant.int 1 %int4096_91 = torch.constant.int 4096 %int9216_92 = torch.constant.int 9216 %65 = torch.prim.ListConstruct %int1_90, %int4096_91, %int9216_92 : (!torch.int, !torch.int, !torch.int) -> !torch.list %66 = torch.aten.view %64, %65 : !torch.vtensor<[4096,9216],f16>, !torch.list -> 
!torch.vtensor<[1,4096,9216],f16> %67 = torch_c.to_builtin_tensor %66 : !torch.vtensor<[1,4096,9216],f16> -> tensor<1x4096x9216xf16> %cast = tensor.cast %67 : tensor<1x4096x9216xf16> to tensor %c0 = arith.constant 0 : index %dim = tensor.dim %cast, %c0 : tensor %c1 = arith.constant 1 : index %dim_93 = tensor.dim %cast, %c1 : tensor %c2 = arith.constant 2 : index %dim_94 = tensor.dim %cast, %c2 : tensor flow.tensor.trace "img_qkv" = [%cast : tensor{%dim, %dim_93, %dim_94}] %cast_95 = tensor.cast %cast : tensor to tensor<1x4096x9216xf16> %68 = torch_c.from_builtin_tensor %cast_95 : tensor<1x4096x9216xf16> -> !torch.vtensor<[1,4096,9216],f16> %int1_96 = torch.constant.int 1 %int4096_97 = torch.constant.int 4096 %int3 = torch.constant.int 3 %int24 = torch.constant.int 24 %int128 = torch.constant.int 128 %69 = torch.prim.ListConstruct %int1_96, %int4096_97, %int3, %int24, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %70 = torch.aten.view %68, %69 : !torch.vtensor<[1,4096,9216],f16>, !torch.list -> !torch.vtensor<[1,4096,3,24,128],f16> %int2_98 = torch.constant.int 2 %int0_99 = torch.constant.int 0 %int3_100 = torch.constant.int 3 %int1_101 = torch.constant.int 1 %int4 = torch.constant.int 4 %71 = torch.prim.ListConstruct %int2_98, %int0_99, %int3_100, %int1_101, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %72 = torch.aten.permute %70, %71 : !torch.vtensor<[1,4096,3,24,128],f16>, !torch.list -> !torch.vtensor<[3,1,24,4096,128],f16> %int0_102 = torch.constant.int 0 %int0_103 = torch.constant.int 0 %73 = torch.aten.select.int %72, %int0_102, %int0_103 : !torch.vtensor<[3,1,24,4096,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %int6_104 = torch.constant.int 6 %74 = torch.prims.convert_element_type %73, %int6_104 : !torch.vtensor<[1,24,4096,128],f16>, !torch.int -> !torch.vtensor<[1,24,4096,128],f32> %int2_105 = torch.constant.int 2 %75 = torch.aten.pow.Tensor_Scalar 
%74, %int2_105 : !torch.vtensor<[1,24,4096,128],f32>, !torch.int -> !torch.vtensor<[1,24,4096,128],f32> %int-1_106 = torch.constant.int -1 %76 = torch.prim.ListConstruct %int-1_106 : (!torch.int) -> !torch.list %true_107 = torch.constant.bool true %none = torch.constant.none %77 = torch.aten.mean.dim %75, %76, %true_107, %none : !torch.vtensor<[1,24,4096,128],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,24,4096,1],f32> %float9.999990e-07_108 = torch.constant.float 9.9999999999999995E-7 %int1_109 = torch.constant.int 1 %78 = torch.aten.add.Scalar %77, %float9.999990e-07_108, %int1_109 : !torch.vtensor<[1,24,4096,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,24,4096,1],f32> %79 = torch.aten.rsqrt %78 : !torch.vtensor<[1,24,4096,1],f32> -> !torch.vtensor<[1,24,4096,1],f32> %80 = torch.aten.mul.Tensor %74, %79 : !torch.vtensor<[1,24,4096,128],f32>, !torch.vtensor<[1,24,4096,1],f32> -> !torch.vtensor<[1,24,4096,128],f32> %int5_110 = torch.constant.int 5 %81 = torch.prims.convert_element_type %80, %int5_110 : !torch.vtensor<[1,24,4096,128],f32>, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %__auto.attn.img_attn.norm.query_norm.scale = util.global.load @__auto.attn.img_attn.norm.query_norm.scale : tensor<128xf16> %82 = torch_c.from_builtin_tensor %__auto.attn.img_attn.norm.query_norm.scale : tensor<128xf16> -> !torch.vtensor<[128],f16> %83 = torch.aten.mul.Tensor %81, %82 : !torch.vtensor<[1,24,4096,128],f16>, !torch.vtensor<[128],f16> -> !torch.vtensor<[1,24,4096,128],f16> %int1_111 = torch.constant.int 1 %int4096_112 = torch.constant.int 4096 %int3_113 = torch.constant.int 3 %int24_114 = torch.constant.int 24 %int128_115 = torch.constant.int 128 %84 = torch.prim.ListConstruct %int1_111, %int4096_112, %int3_113, %int24_114, %int128_115 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %85 = torch.aten.view %68, %84 : !torch.vtensor<[1,4096,9216],f16>, !torch.list -> !torch.vtensor<[1,4096,3,24,128],f16> 
%int2_116 = torch.constant.int 2 %int0_117 = torch.constant.int 0 %int3_118 = torch.constant.int 3 %int1_119 = torch.constant.int 1 %int4_120 = torch.constant.int 4 %86 = torch.prim.ListConstruct %int2_116, %int0_117, %int3_118, %int1_119, %int4_120 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %87 = torch.aten.permute %85, %86 : !torch.vtensor<[1,4096,3,24,128],f16>, !torch.list -> !torch.vtensor<[3,1,24,4096,128],f16> %int0_121 = torch.constant.int 0 %int1_122 = torch.constant.int 1 %88 = torch.aten.select.int %87, %int0_121, %int1_122 : !torch.vtensor<[3,1,24,4096,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %int6_123 = torch.constant.int 6 %89 = torch.prims.convert_element_type %88, %int6_123 : !torch.vtensor<[1,24,4096,128],f16>, !torch.int -> !torch.vtensor<[1,24,4096,128],f32> %int2_124 = torch.constant.int 2 %90 = torch.aten.pow.Tensor_Scalar %89, %int2_124 : !torch.vtensor<[1,24,4096,128],f32>, !torch.int -> !torch.vtensor<[1,24,4096,128],f32> %int-1_125 = torch.constant.int -1 %91 = torch.prim.ListConstruct %int-1_125 : (!torch.int) -> !torch.list %true_126 = torch.constant.bool true %none_127 = torch.constant.none %92 = torch.aten.mean.dim %90, %91, %true_126, %none_127 : !torch.vtensor<[1,24,4096,128],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,24,4096,1],f32> %float9.999990e-07_128 = torch.constant.float 9.9999999999999995E-7 %int1_129 = torch.constant.int 1 %93 = torch.aten.add.Scalar %92, %float9.999990e-07_128, %int1_129 : !torch.vtensor<[1,24,4096,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,24,4096,1],f32> %94 = torch.aten.rsqrt %93 : !torch.vtensor<[1,24,4096,1],f32> -> !torch.vtensor<[1,24,4096,1],f32> %95 = torch.aten.mul.Tensor %89, %94 : !torch.vtensor<[1,24,4096,128],f32>, !torch.vtensor<[1,24,4096,1],f32> -> !torch.vtensor<[1,24,4096,128],f32> %int5_130 = torch.constant.int 5 %96 = torch.prims.convert_element_type %95, %int5_130 : 
!torch.vtensor<[1,24,4096,128],f32>, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %__auto.attn.img_attn.norm.key_norm.scale = util.global.load @__auto.attn.img_attn.norm.key_norm.scale : tensor<128xf16> %97 = torch_c.from_builtin_tensor %__auto.attn.img_attn.norm.key_norm.scale : tensor<128xf16> -> !torch.vtensor<[128],f16> %98 = torch.aten.mul.Tensor %96, %97 : !torch.vtensor<[1,24,4096,128],f16>, !torch.vtensor<[128],f16> -> !torch.vtensor<[1,24,4096,128],f16> %int5_131 = torch.constant.int 5 %99 = torch.prims.convert_element_type %83, %int5_131 : !torch.vtensor<[1,24,4096,128],f16>, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %int5_132 = torch.constant.int 5 %100 = torch.prims.convert_element_type %98, %int5_132 : !torch.vtensor<[1,24,4096,128],f16>, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %int6_133 = torch.constant.int 6 %101 = torch.prims.convert_element_type %arg1, %int6_133 : !torch.vtensor<[1,512,3072],f16>, !torch.int -> !torch.vtensor<[1,512,3072],f32> %int2_134 = torch.constant.int 2 %102 = torch.prim.ListConstruct %int2_134 : (!torch.int) -> !torch.list %int0_135 = torch.constant.int 0 %true_136 = torch.constant.bool true %result0_137, %result1_138 = torch.aten.var_mean.correction %101, %102, %int0_135, %true_136 : !torch.vtensor<[1,512,3072],f32>, !torch.list, !torch.int, !torch.bool -> !torch.vtensor<[1,512,1],f32>, !torch.vtensor<[1,512,1],f32> %float9.999990e-07_139 = torch.constant.float 9.9999999999999995E-7 %int1_140 = torch.constant.int 1 %103 = torch.aten.add.Scalar %result0_137, %float9.999990e-07_139, %int1_140 : !torch.vtensor<[1,512,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,512,1],f32> %104 = torch.aten.rsqrt %103 : !torch.vtensor<[1,512,1],f32> -> !torch.vtensor<[1,512,1],f32> %int1_141 = torch.constant.int 1 %105 = torch.aten.sub.Tensor %arg1, %result1_138, %int1_141 : !torch.vtensor<[1,512,3072],f16>, !torch.vtensor<[1,512,1],f32>, !torch.int -> !torch.vtensor<[1,512,3072],f32> %106 = 
torch.aten.mul.Tensor %105, %104 : !torch.vtensor<[1,512,3072],f32>, !torch.vtensor<[1,512,1],f32> -> !torch.vtensor<[1,512,3072],f32> %int5_142 = torch.constant.int 5 %107 = torch.prims.convert_element_type %106, %int5_142 : !torch.vtensor<[1,512,3072],f32>, !torch.int -> !torch.vtensor<[1,512,3072],f16> %int1_143 = torch.constant.int 1 %int1_144 = torch.constant.int 1 %108 = torch.aten.add.Scalar %37, %int1_143, %int1_144 : !torch.vtensor<[1,1,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %109 = torch.aten.mul.Tensor %108, %107 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,512,3072],f16> -> !torch.vtensor<[1,512,3072],f16> %int1_145 = torch.constant.int 1 %110 = torch.aten.add.Tensor %109, %36, %int1_145 : !torch.vtensor<[1,512,3072],f16>, !torch.vtensor<[1,1,3072],f16>, !torch.int -> !torch.vtensor<[1,512,3072],f16> %int512 = torch.constant.int 512 %int3072_146 = torch.constant.int 3072 %111 = torch.prim.ListConstruct %int512, %int3072_146 : (!torch.int, !torch.int) -> !torch.list %112 = torch.aten.view %110, %111 : !torch.vtensor<[1,512,3072],f16>, !torch.list -> !torch.vtensor<[512,3072],f16> %__auto.attn.txt_attn.qkv.weight = util.global.load @__auto.attn.txt_attn.qkv.weight : tensor<9216x3072xf16> %113 = torch_c.from_builtin_tensor %__auto.attn.txt_attn.qkv.weight : tensor<9216x3072xf16> -> !torch.vtensor<[9216,3072],f16> %int0_147 = torch.constant.int 0 %int1_148 = torch.constant.int 1 %114 = torch.aten.transpose.int %113, %int0_147, %int1_148 : !torch.vtensor<[9216,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,9216],f16> %__auto.attn.txt_attn.qkv.bias = util.global.load @__auto.attn.txt_attn.qkv.bias : tensor<9216xf16> %115 = torch_c.from_builtin_tensor %__auto.attn.txt_attn.qkv.bias : tensor<9216xf16> -> !torch.vtensor<[9216],f16> %int6_149 = torch.constant.int 6 %116 = torch.prims.convert_element_type %115, %int6_149 : !torch.vtensor<[9216],f16>, !torch.int -> !torch.vtensor<[9216],f32> %int6_150 = 
torch.constant.int 6 %117 = torch.prims.convert_element_type %112, %int6_150 : !torch.vtensor<[512,3072],f16>, !torch.int -> !torch.vtensor<[512,3072],f32> %int6_151 = torch.constant.int 6 %118 = torch.prims.convert_element_type %114, %int6_151 : !torch.vtensor<[3072,9216],f16>, !torch.int -> !torch.vtensor<[3072,9216],f32> %119 = torch.aten.mm %117, %118 : !torch.vtensor<[512,3072],f32>, !torch.vtensor<[3072,9216],f32> -> !torch.vtensor<[512,9216],f32> %int1_152 = torch.constant.int 1 %120 = torch.aten.mul.Scalar %119, %int1_152 : !torch.vtensor<[512,9216],f32>, !torch.int -> !torch.vtensor<[512,9216],f32> %int1_153 = torch.constant.int 1 %121 = torch.aten.mul.Scalar %116, %int1_153 : !torch.vtensor<[9216],f32>, !torch.int -> !torch.vtensor<[9216],f32> %int1_154 = torch.constant.int 1 %122 = torch.aten.add.Tensor %120, %121, %int1_154 : !torch.vtensor<[512,9216],f32>, !torch.vtensor<[9216],f32>, !torch.int -> !torch.vtensor<[512,9216],f32> %int5_155 = torch.constant.int 5 %123 = torch.prims.convert_element_type %122, %int5_155 : !torch.vtensor<[512,9216],f32>, !torch.int -> !torch.vtensor<[512,9216],f16> %int1_156 = torch.constant.int 1 %int512_157 = torch.constant.int 512 %int9216_158 = torch.constant.int 9216 %124 = torch.prim.ListConstruct %int1_156, %int512_157, %int9216_158 : (!torch.int, !torch.int, !torch.int) -> !torch.list %125 = torch.aten.view %123, %124 : !torch.vtensor<[512,9216],f16>, !torch.list -> !torch.vtensor<[1,512,9216],f16> %126 = torch_c.to_builtin_tensor %125 : !torch.vtensor<[1,512,9216],f16> -> tensor<1x512x9216xf16> %cast_159 = tensor.cast %126 : tensor<1x512x9216xf16> to tensor %c0_160 = arith.constant 0 : index %dim_161 = tensor.dim %cast_159, %c0_160 : tensor %c1_162 = arith.constant 1 : index %dim_163 = tensor.dim %cast_159, %c1_162 : tensor %c2_164 = arith.constant 2 : index %dim_165 = tensor.dim %cast_159, %c2_164 : tensor flow.tensor.trace "txt_qkv" = [%cast_159 : tensor{%dim_161, %dim_163, %dim_165}] %cast_166 = tensor.cast 
%cast_159 : tensor to tensor<1x512x9216xf16> %127 = torch_c.from_builtin_tensor %cast_166 : tensor<1x512x9216xf16> -> !torch.vtensor<[1,512,9216],f16> %int1_167 = torch.constant.int 1 %int512_168 = torch.constant.int 512 %int3_169 = torch.constant.int 3 %int24_170 = torch.constant.int 24 %int128_171 = torch.constant.int 128 %128 = torch.prim.ListConstruct %int1_167, %int512_168, %int3_169, %int24_170, %int128_171 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %129 = torch.aten.view %127, %128 : !torch.vtensor<[1,512,9216],f16>, !torch.list -> !torch.vtensor<[1,512,3,24,128],f16> %int2_172 = torch.constant.int 2 %int0_173 = torch.constant.int 0 %int3_174 = torch.constant.int 3 %int1_175 = torch.constant.int 1 %int4_176 = torch.constant.int 4 %130 = torch.prim.ListConstruct %int2_172, %int0_173, %int3_174, %int1_175, %int4_176 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %131 = torch.aten.permute %129, %130 : !torch.vtensor<[1,512,3,24,128],f16>, !torch.list -> !torch.vtensor<[3,1,24,512,128],f16> %int0_177 = torch.constant.int 0 %int0_178 = torch.constant.int 0 %132 = torch.aten.select.int %131, %int0_177, %int0_178 : !torch.vtensor<[3,1,24,512,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %int6_179 = torch.constant.int 6 %133 = torch.prims.convert_element_type %132, %int6_179 : !torch.vtensor<[1,24,512,128],f16>, !torch.int -> !torch.vtensor<[1,24,512,128],f32> %int2_180 = torch.constant.int 2 %134 = torch.aten.pow.Tensor_Scalar %133, %int2_180 : !torch.vtensor<[1,24,512,128],f32>, !torch.int -> !torch.vtensor<[1,24,512,128],f32> %int-1_181 = torch.constant.int -1 %135 = torch.prim.ListConstruct %int-1_181 : (!torch.int) -> !torch.list %true_182 = torch.constant.bool true %none_183 = torch.constant.none %136 = torch.aten.mean.dim %134, %135, %true_182, %none_183 : !torch.vtensor<[1,24,512,128],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,24,512,1],f32> 
%float9.999990e-07_184 = torch.constant.float 9.9999999999999995E-7 %int1_185 = torch.constant.int 1 %137 = torch.aten.add.Scalar %136, %float9.999990e-07_184, %int1_185 : !torch.vtensor<[1,24,512,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,24,512,1],f32> %138 = torch.aten.rsqrt %137 : !torch.vtensor<[1,24,512,1],f32> -> !torch.vtensor<[1,24,512,1],f32> %139 = torch.aten.mul.Tensor %133, %138 : !torch.vtensor<[1,24,512,128],f32>, !torch.vtensor<[1,24,512,1],f32> -> !torch.vtensor<[1,24,512,128],f32> %int5_186 = torch.constant.int 5 %140 = torch.prims.convert_element_type %139, %int5_186 : !torch.vtensor<[1,24,512,128],f32>, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %__auto.attn.txt_attn.norm.query_norm.scale = util.global.load @__auto.attn.txt_attn.norm.query_norm.scale : tensor<128xf16> %141 = torch_c.from_builtin_tensor %__auto.attn.txt_attn.norm.query_norm.scale : tensor<128xf16> -> !torch.vtensor<[128],f16> %142 = torch.aten.mul.Tensor %140, %141 : !torch.vtensor<[1,24,512,128],f16>, !torch.vtensor<[128],f16> -> !torch.vtensor<[1,24,512,128],f16> %int1_187 = torch.constant.int 1 %int512_188 = torch.constant.int 512 %int3_189 = torch.constant.int 3 %int24_190 = torch.constant.int 24 %int128_191 = torch.constant.int 128 %143 = torch.prim.ListConstruct %int1_187, %int512_188, %int3_189, %int24_190, %int128_191 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %144 = torch.aten.view %127, %143 : !torch.vtensor<[1,512,9216],f16>, !torch.list -> !torch.vtensor<[1,512,3,24,128],f16> %int2_192 = torch.constant.int 2 %int0_193 = torch.constant.int 0 %int3_194 = torch.constant.int 3 %int1_195 = torch.constant.int 1 %int4_196 = torch.constant.int 4 %145 = torch.prim.ListConstruct %int2_192, %int0_193, %int3_194, %int1_195, %int4_196 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %146 = torch.aten.permute %144, %145 : !torch.vtensor<[1,512,3,24,128],f16>, !torch.list -> 
!torch.vtensor<[3,1,24,512,128],f16> %int0_197 = torch.constant.int 0 %int1_198 = torch.constant.int 1 %147 = torch.aten.select.int %146, %int0_197, %int1_198 : !torch.vtensor<[3,1,24,512,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %int6_199 = torch.constant.int 6 %148 = torch.prims.convert_element_type %147, %int6_199 : !torch.vtensor<[1,24,512,128],f16>, !torch.int -> !torch.vtensor<[1,24,512,128],f32> %int2_200 = torch.constant.int 2 %149 = torch.aten.pow.Tensor_Scalar %148, %int2_200 : !torch.vtensor<[1,24,512,128],f32>, !torch.int -> !torch.vtensor<[1,24,512,128],f32> %int-1_201 = torch.constant.int -1 %150 = torch.prim.ListConstruct %int-1_201 : (!torch.int) -> !torch.list %true_202 = torch.constant.bool true %none_203 = torch.constant.none %151 = torch.aten.mean.dim %149, %150, %true_202, %none_203 : !torch.vtensor<[1,24,512,128],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,24,512,1],f32> %float9.999990e-07_204 = torch.constant.float 9.9999999999999995E-7 %int1_205 = torch.constant.int 1 %152 = torch.aten.add.Scalar %151, %float9.999990e-07_204, %int1_205 : !torch.vtensor<[1,24,512,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,24,512,1],f32> %153 = torch.aten.rsqrt %152 : !torch.vtensor<[1,24,512,1],f32> -> !torch.vtensor<[1,24,512,1],f32> %154 = torch.aten.mul.Tensor %148, %153 : !torch.vtensor<[1,24,512,128],f32>, !torch.vtensor<[1,24,512,1],f32> -> !torch.vtensor<[1,24,512,128],f32> %int5_206 = torch.constant.int 5 %155 = torch.prims.convert_element_type %154, %int5_206 : !torch.vtensor<[1,24,512,128],f32>, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %__auto.attn.txt_attn.norm.key_norm.scale = util.global.load @__auto.attn.txt_attn.norm.key_norm.scale : tensor<128xf16> %156 = torch_c.from_builtin_tensor %__auto.attn.txt_attn.norm.key_norm.scale : tensor<128xf16> -> !torch.vtensor<[128],f16> %157 = torch.aten.mul.Tensor %155, %156 : !torch.vtensor<[1,24,512,128],f16>, !torch.vtensor<[128],f16> 
-> !torch.vtensor<[1,24,512,128],f16> %int5_207 = torch.constant.int 5 %158 = torch.prims.convert_element_type %142, %int5_207 : !torch.vtensor<[1,24,512,128],f16>, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %int5_208 = torch.constant.int 5 %159 = torch.prims.convert_element_type %157, %int5_208 : !torch.vtensor<[1,24,512,128],f16>, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %160 = torch.prim.ListConstruct %158, %99 : (!torch.vtensor<[1,24,512,128],f16>, !torch.vtensor<[1,24,4096,128],f16>) -> !torch.list %int2_209 = torch.constant.int 2 %161 = torch.aten.cat %160, %int2_209 : !torch.list, !torch.int -> !torch.vtensor<[1,24,4608,128],f16> %162 = torch.prim.ListConstruct %159, %100 : (!torch.vtensor<[1,24,512,128],f16>, !torch.vtensor<[1,24,4096,128],f16>) -> !torch.list %int2_210 = torch.constant.int 2 %163 = torch.aten.cat %162, %int2_210 : !torch.list, !torch.int -> !torch.vtensor<[1,24,4608,128],f16> %int1_211 = torch.constant.int 1 %int512_212 = torch.constant.int 512 %int3_213 = torch.constant.int 3 %int24_214 = torch.constant.int 24 %int128_215 = torch.constant.int 128 %164 = torch.prim.ListConstruct %int1_211, %int512_212, %int3_213, %int24_214, %int128_215 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %165 = torch.aten.view %127, %164 : !torch.vtensor<[1,512,9216],f16>, !torch.list -> !torch.vtensor<[1,512,3,24,128],f16> %int2_216 = torch.constant.int 2 %int0_217 = torch.constant.int 0 %int3_218 = torch.constant.int 3 %int1_219 = torch.constant.int 1 %int4_220 = torch.constant.int 4 %166 = torch.prim.ListConstruct %int2_216, %int0_217, %int3_218, %int1_219, %int4_220 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %167 = torch.aten.permute %165, %166 : !torch.vtensor<[1,512,3,24,128],f16>, !torch.list -> !torch.vtensor<[3,1,24,512,128],f16> %int0_221 = torch.constant.int 0 %int2_222 = torch.constant.int 2 %168 = torch.aten.select.int %167, %int0_221, %int2_222 : 
!torch.vtensor<[3,1,24,512,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,24,512,128],f16> %int1_223 = torch.constant.int 1 %int4096_224 = torch.constant.int 4096 %int3_225 = torch.constant.int 3 %int24_226 = torch.constant.int 24 %int128_227 = torch.constant.int 128 %169 = torch.prim.ListConstruct %int1_223, %int4096_224, %int3_225, %int24_226, %int128_227 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %170 = torch.aten.view %68, %169 : !torch.vtensor<[1,4096,9216],f16>, !torch.list -> !torch.vtensor<[1,4096,3,24,128],f16> %int2_228 = torch.constant.int 2 %int0_229 = torch.constant.int 0 %int3_230 = torch.constant.int 3 %int1_231 = torch.constant.int 1 %int4_232 = torch.constant.int 4 %171 = torch.prim.ListConstruct %int2_228, %int0_229, %int3_230, %int1_231, %int4_232 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %172 = torch.aten.permute %170, %171 : !torch.vtensor<[1,4096,3,24,128],f16>, !torch.list -> !torch.vtensor<[3,1,24,4096,128],f16> %int0_233 = torch.constant.int 0 %int2_234 = torch.constant.int 2 %173 = torch.aten.select.int %172, %int0_233, %int2_234 : !torch.vtensor<[3,1,24,4096,128],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4096,128],f16> %174 = torch.prim.ListConstruct %168, %173 : (!torch.vtensor<[1,24,512,128],f16>, !torch.vtensor<[1,24,4096,128],f16>) -> !torch.list %int2_235 = torch.constant.int 2 %175 = torch.aten.cat %174, %int2_235 : !torch.list, !torch.int -> !torch.vtensor<[1,24,4608,128],f16> %176 = torch_c.to_builtin_tensor %161 : !torch.vtensor<[1,24,4608,128],f16> -> tensor<1x24x4608x128xf16> %cast_236 = tensor.cast %176 : tensor<1x24x4608x128xf16> to tensor %c0_237 = arith.constant 0 : index %dim_238 = tensor.dim %cast_236, %c0_237 : tensor %c1_239 = arith.constant 1 : index %dim_240 = tensor.dim %cast_236, %c1_239 : tensor %c2_241 = arith.constant 2 : index %dim_242 = tensor.dim %cast_236, %c2_241 : tensor %c3 = arith.constant 3 : index %dim_243 = 
tensor.dim %cast_236, %c3 : tensor flow.tensor.trace "q" = [%cast_236 : tensor{%dim_238, %dim_240, %dim_242, %dim_243}] %cast_244 = tensor.cast %cast_236 : tensor to tensor<1x24x4608x128xf16> %177 = torch_c.from_builtin_tensor %cast_244 : tensor<1x24x4608x128xf16> -> !torch.vtensor<[1,24,4608,128],f16> %178 = torch_c.to_builtin_tensor %163 : !torch.vtensor<[1,24,4608,128],f16> -> tensor<1x24x4608x128xf16> %cast_245 = tensor.cast %178 : tensor<1x24x4608x128xf16> to tensor %c0_246 = arith.constant 0 : index %dim_247 = tensor.dim %cast_245, %c0_246 : tensor %c1_248 = arith.constant 1 : index %dim_249 = tensor.dim %cast_245, %c1_248 : tensor %c2_250 = arith.constant 2 : index %dim_251 = tensor.dim %cast_245, %c2_250 : tensor %c3_252 = arith.constant 3 : index %dim_253 = tensor.dim %cast_245, %c3_252 : tensor flow.tensor.trace "k" = [%cast_245 : tensor{%dim_247, %dim_249, %dim_251, %dim_253}] %cast_254 = tensor.cast %cast_245 : tensor to tensor<1x24x4608x128xf16> %179 = torch_c.from_builtin_tensor %cast_254 : tensor<1x24x4608x128xf16> -> !torch.vtensor<[1,24,4608,128],f16> %180 = torch_c.to_builtin_tensor %175 : !torch.vtensor<[1,24,4608,128],f16> -> tensor<1x24x4608x128xf16> %cast_255 = tensor.cast %180 : tensor<1x24x4608x128xf16> to tensor %c0_256 = arith.constant 0 : index %dim_257 = tensor.dim %cast_255, %c0_256 : tensor %c1_258 = arith.constant 1 : index %dim_259 = tensor.dim %cast_255, %c1_258 : tensor %c2_260 = arith.constant 2 : index %dim_261 = tensor.dim %cast_255, %c2_260 : tensor %c3_262 = arith.constant 3 : index %dim_263 = tensor.dim %cast_255, %c3_262 : tensor flow.tensor.trace "v" = [%cast_255 : tensor{%dim_257, %dim_259, %dim_261, %dim_263}] %cast_264 = tensor.cast %cast_255 : tensor to tensor<1x24x4608x128xf16> %181 = torch_c.from_builtin_tensor %cast_264 : tensor<1x24x4608x128xf16> -> !torch.vtensor<[1,24,4608,128],f16> %int6_265 = torch.constant.int 6 %182 = torch.prims.convert_element_type %177, %int6_265 : !torch.vtensor<[1,24,4608,128],f16>, 
!torch.int -> !torch.vtensor<[1,24,4608,128],f32> %int1_266 = torch.constant.int 1 %int24_267 = torch.constant.int 24 %int4608 = torch.constant.int 4608 %int-1_268 = torch.constant.int -1 %int1_269 = torch.constant.int 1 %int2_270 = torch.constant.int 2 %183 = torch.prim.ListConstruct %int1_266, %int24_267, %int4608, %int-1_268, %int1_269, %int2_270 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %184 = torch.aten.view %182, %183 : !torch.vtensor<[1,24,4608,128],f32>, !torch.list -> !torch.vtensor<[1,24,4608,64,1,2],f32> %int6_271 = torch.constant.int 6 %185 = torch.prims.convert_element_type %179, %int6_271 : !torch.vtensor<[1,24,4608,128],f16>, !torch.int -> !torch.vtensor<[1,24,4608,128],f32> %int1_272 = torch.constant.int 1 %int24_273 = torch.constant.int 24 %int4608_274 = torch.constant.int 4608 %int-1_275 = torch.constant.int -1 %int1_276 = torch.constant.int 1 %int2_277 = torch.constant.int 2 %186 = torch.prim.ListConstruct %int1_272, %int24_273, %int4608_274, %int-1_275, %int1_276, %int2_277 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %187 = torch.aten.view %185, %186 : !torch.vtensor<[1,24,4608,128],f32>, !torch.list -> !torch.vtensor<[1,24,4608,64,1,2],f32> %int5_278 = torch.constant.int 5 %int0_279 = torch.constant.int 0 %188 = torch.aten.select.int %arg3, %int5_278, %int0_279 : !torch.vtensor<[1,1,4608,64,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,4608,64,2],f32> %int5_280 = torch.constant.int 5 %int0_281 = torch.constant.int 0 %189 = torch.aten.select.int %184, %int5_280, %int0_281 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32> %190 = torch.aten.mul.Tensor %188, %189 : !torch.vtensor<[1,1,4608,64,2],f32>, !torch.vtensor<[1,24,4608,64,1],f32> -> !torch.vtensor<[1,24,4608,64,2],f32> %int5_282 = torch.constant.int 5 %int1_283 = torch.constant.int 1 %191 = torch.aten.select.int %arg3, %int5_282, 
%int1_283 : !torch.vtensor<[1,1,4608,64,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,4608,64,2],f32> %int5_284 = torch.constant.int 5 %int1_285 = torch.constant.int 1 %192 = torch.aten.select.int %184, %int5_284, %int1_285 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32> %193 = torch.aten.mul.Tensor %191, %192 : !torch.vtensor<[1,1,4608,64,2],f32>, !torch.vtensor<[1,24,4608,64,1],f32> -> !torch.vtensor<[1,24,4608,64,2],f32> %int1_286 = torch.constant.int 1 %194 = torch.aten.add.Tensor %190, %193, %int1_286 : !torch.vtensor<[1,24,4608,64,2],f32>, !torch.vtensor<[1,24,4608,64,2],f32>, !torch.int -> !torch.vtensor<[1,24,4608,64,2],f32> %int5_287 = torch.constant.int 5 %int0_288 = torch.constant.int 0 %195 = torch.aten.select.int %arg3, %int5_287, %int0_288 : !torch.vtensor<[1,1,4608,64,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,4608,64,2],f32> %int5_289 = torch.constant.int 5 %int0_290 = torch.constant.int 0 %196 = torch.aten.select.int %187, %int5_289, %int0_290 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32> %197 = torch.aten.mul.Tensor %195, %196 : !torch.vtensor<[1,1,4608,64,2],f32>, !torch.vtensor<[1,24,4608,64,1],f32> -> !torch.vtensor<[1,24,4608,64,2],f32> %int5_291 = torch.constant.int 5 %int1_292 = torch.constant.int 1 %198 = torch.aten.select.int %arg3, %int5_291, %int1_292 : !torch.vtensor<[1,1,4608,64,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,4608,64,2],f32> %int5_293 = torch.constant.int 5 %int1_294 = torch.constant.int 1 %199 = torch.aten.select.int %187, %int5_293, %int1_294 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32> %200 = torch.aten.mul.Tensor %198, %199 : !torch.vtensor<[1,1,4608,64,2],f32>, !torch.vtensor<[1,24,4608,64,1],f32> -> !torch.vtensor<[1,24,4608,64,2],f32> %int1_295 = torch.constant.int 1 %201 = torch.aten.add.Tensor %197, 
%200, %int1_295 : !torch.vtensor<[1,24,4608,64,2],f32>, !torch.vtensor<[1,24,4608,64,2],f32>, !torch.int -> !torch.vtensor<[1,24,4608,64,2],f32> %int1_296 = torch.constant.int 1 %int24_297 = torch.constant.int 24 %int4608_298 = torch.constant.int 4608 %int128_299 = torch.constant.int 128 %202 = torch.prim.ListConstruct %int1_296, %int24_297, %int4608_298, %int128_299 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %203 = torch.aten.view %194, %202 : !torch.vtensor<[1,24,4608,64,2],f32>, !torch.list -> !torch.vtensor<[1,24,4608,128],f32> %int5_300 = torch.constant.int 5 %204 = torch.prims.convert_element_type %203, %int5_300 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16> %int1_301 = torch.constant.int 1 %int24_302 = torch.constant.int 24 %int4608_303 = torch.constant.int 4608 %int128_304 = torch.constant.int 128 %205 = torch.prim.ListConstruct %int1_301, %int24_302, %int4608_303, %int128_304 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %206 = torch.aten.view %201, %205 : !torch.vtensor<[1,24,4608,64,2],f32>, !torch.list -> !torch.vtensor<[1,24,4608,128],f32> %int5_305 = torch.constant.int 5 %207 = torch.prims.convert_element_type %206, %int5_305 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16> %float0.000000e00 = torch.constant.float 0.000000e+00 %false = torch.constant.bool false %none_306 = torch.constant.none %none_307 = torch.constant.none %208:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%204, %207, %181, %float0.000000e00, %false, %none_306, %none_307) : (!torch.vtensor<[1,24,4608,128],f16>, !torch.vtensor<[1,24,4608,128],f16>, !torch.vtensor<[1,24,4608,128],f16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,24,4608,128],f16>, !torch.vtensor<[1,24,4608],f32>) %int0_308 = torch.constant.int 0 %int2_309 = torch.constant.int 2 %int1_310 = torch.constant.int 1 %int3_311 = 
torch.constant.int 3 %209 = torch.prim.ListConstruct %int0_308, %int2_309, %int1_310, %int3_311 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list %210 = torch.aten.permute %208#0, %209 : !torch.vtensor<[1,24,4608,128],f16>, !torch.list -> !torch.vtensor<[1,4608,24,128],f16> %int1_312 = torch.constant.int 1 %int4608_313 = torch.constant.int 4608 %int3072_314 = torch.constant.int 3072 %211 = torch.prim.ListConstruct %int1_312, %int4608_313, %int3072_314 : (!torch.int, !torch.int, !torch.int) -> !torch.list %212 = torch.aten.view %210, %211 : !torch.vtensor<[1,4608,24,128],f16>, !torch.list -> !torch.vtensor<[1,4608,3072],f16> %int0_315 = torch.constant.int 0 %int0_316 = torch.constant.int 0 %int9223372036854775807_317 = torch.constant.int 9223372036854775807 %int1_318 = torch.constant.int 1 %213 = torch.aten.slice.Tensor %212, %int0_315, %int0_316, %int9223372036854775807_317, %int1_318 : !torch.vtensor<[1,4608,3072],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4608,3072],f16> %int1_319 = torch.constant.int 1 %int0_320 = torch.constant.int 0 %int512_321 = torch.constant.int 512 %int1_322 = torch.constant.int 1 %214 = torch.aten.slice.Tensor %213, %int1_319, %int0_320, %int512_321, %int1_322 : !torch.vtensor<[1,4608,3072],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,512,3072],f16> %int0_323 = torch.constant.int 0 %int0_324 = torch.constant.int 0 %int9223372036854775807_325 = torch.constant.int 9223372036854775807 %int1_326 = torch.constant.int 1 %215 = torch.aten.slice.Tensor %212, %int0_323, %int0_324, %int9223372036854775807_325, %int1_326 : !torch.vtensor<[1,4608,3072],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4608,3072],f16> %int1_327 = torch.constant.int 1 %int512_328 = torch.constant.int 512 %int9223372036854775807_329 = torch.constant.int 9223372036854775807 %int1_330 = torch.constant.int 1 %216 = torch.aten.slice.Tensor %215, %int1_327, %int512_328, 
%int9223372036854775807_329, %int1_330 : !torch.vtensor<[1,4608,3072],f16>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %int4096_331 = torch.constant.int 4096 %int3072_332 = torch.constant.int 3072 %217 = torch.prim.ListConstruct %int4096_331, %int3072_332 : (!torch.int, !torch.int) -> !torch.list %218 = torch.aten.view %216, %217 : !torch.vtensor<[1,4096,3072],f16>, !torch.list -> !torch.vtensor<[4096,3072],f16> %__auto.attn.img_attn.proj.weight = util.global.load @__auto.attn.img_attn.proj.weight : tensor<3072x3072xf16> %219 = torch_c.from_builtin_tensor %__auto.attn.img_attn.proj.weight : tensor<3072x3072xf16> -> !torch.vtensor<[3072,3072],f16> %int0_333 = torch.constant.int 0 %int1_334 = torch.constant.int 1 %220 = torch.aten.transpose.int %219, %int0_333, %int1_334 : !torch.vtensor<[3072,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,3072],f16> %__auto.attn.img_attn.proj.bias = util.global.load @__auto.attn.img_attn.proj.bias : tensor<3072xf16> %221 = torch_c.from_builtin_tensor %__auto.attn.img_attn.proj.bias : tensor<3072xf16> -> !torch.vtensor<[3072],f16> %int6_335 = torch.constant.int 6 %222 = torch.prims.convert_element_type %221, %int6_335 : !torch.vtensor<[3072],f16>, !torch.int -> !torch.vtensor<[3072],f32> %int6_336 = torch.constant.int 6 %223 = torch.prims.convert_element_type %218, %int6_336 : !torch.vtensor<[4096,3072],f16>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int6_337 = torch.constant.int 6 %224 = torch.prims.convert_element_type %220, %int6_337 : !torch.vtensor<[3072,3072],f16>, !torch.int -> !torch.vtensor<[3072,3072],f32> %225 = torch.aten.mm %223, %224 : !torch.vtensor<[4096,3072],f32>, !torch.vtensor<[3072,3072],f32> -> !torch.vtensor<[4096,3072],f32> %int1_338 = torch.constant.int 1 %226 = torch.aten.mul.Scalar %225, %int1_338 : !torch.vtensor<[4096,3072],f32>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int1_339 = torch.constant.int 1 %227 = torch.aten.mul.Scalar %222, 
%int1_339 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> %int1_340 = torch.constant.int 1 %228 = torch.aten.add.Tensor %226, %227, %int1_340 : !torch.vtensor<[4096,3072],f32>, !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int5_341 = torch.constant.int 5 %229 = torch.prims.convert_element_type %228, %int5_341 : !torch.vtensor<[4096,3072],f32>, !torch.int -> !torch.vtensor<[4096,3072],f16> %int1_342 = torch.constant.int 1 %int4096_343 = torch.constant.int 4096 %int3072_344 = torch.constant.int 3072 %230 = torch.prim.ListConstruct %int1_342, %int4096_343, %int3072_344 : (!torch.int, !torch.int, !torch.int) -> !torch.list %231 = torch.aten.view %229, %230 : !torch.vtensor<[4096,3072],f16>, !torch.list -> !torch.vtensor<[1,4096,3072],f16> %232 = torch.aten.mul.Tensor %17, %231 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,4096,3072],f16> -> !torch.vtensor<[1,4096,3072],f16> %int1_345 = torch.constant.int 1 %233 = torch.aten.add.Tensor %arg0, %232, %int1_345 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,4096,3072],f16>, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %int1_346 = torch.constant.int 1 %int1_347 = torch.constant.int 1 %234 = torch.aten.add.Scalar %19, %int1_346, %int1_347 : !torch.vtensor<[1,1,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16> %int6_348 = torch.constant.int 6 %235 = torch.prims.convert_element_type %233, %int6_348 : !torch.vtensor<[1,4096,3072],f16>, !torch.int -> !torch.vtensor<[1,4096,3072],f32> %int2_349 = torch.constant.int 2 %236 = torch.prim.ListConstruct %int2_349 : (!torch.int) -> !torch.list %int0_350 = torch.constant.int 0 %true_351 = torch.constant.bool true %result0_352, %result1_353 = torch.aten.var_mean.correction %235, %236, %int0_350, %true_351 : !torch.vtensor<[1,4096,3072],f32>, !torch.list, !torch.int, !torch.bool -> !torch.vtensor<[1,4096,1],f32>, !torch.vtensor<[1,4096,1],f32> %float9.999990e-07_354 = torch.constant.float 
9.9999999999999995E-7 %int1_355 = torch.constant.int 1 %237 = torch.aten.add.Scalar %result0_352, %float9.999990e-07_354, %int1_355 : !torch.vtensor<[1,4096,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,4096,1],f32> %238 = torch.aten.rsqrt %237 : !torch.vtensor<[1,4096,1],f32> -> !torch.vtensor<[1,4096,1],f32> %int1_356 = torch.constant.int 1 %239 = torch.aten.sub.Tensor %233, %result1_353, %int1_356 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,4096,1],f32>, !torch.int -> !torch.vtensor<[1,4096,3072],f32> %240 = torch.aten.mul.Tensor %239, %238 : !torch.vtensor<[1,4096,3072],f32>, !torch.vtensor<[1,4096,1],f32> -> !torch.vtensor<[1,4096,3072],f32> %int5_357 = torch.constant.int 5 %241 = torch.prims.convert_element_type %240, %int5_357 : !torch.vtensor<[1,4096,3072],f32>, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %242 = torch.aten.mul.Tensor %234, %241 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,4096,3072],f16> -> !torch.vtensor<[1,4096,3072],f16> %int1_358 = torch.constant.int 1 %243 = torch.aten.add.Tensor %242, %18, %int1_358 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,1,3072],f16>, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %int4096_359 = torch.constant.int 4096 %int3072_360 = torch.constant.int 3072 %244 = torch.prim.ListConstruct %int4096_359, %int3072_360 : (!torch.int, !torch.int) -> !torch.list %245 = torch.aten.view %243, %244 : !torch.vtensor<[1,4096,3072],f16>, !torch.list -> !torch.vtensor<[4096,3072],f16> %__auto.attn.img_mlp.0.weight = util.global.load @__auto.attn.img_mlp.0.weight : tensor<12288x3072xf16> %246 = torch_c.from_builtin_tensor %__auto.attn.img_mlp.0.weight : tensor<12288x3072xf16> -> !torch.vtensor<[12288,3072],f16> %int0_361 = torch.constant.int 0 %int1_362 = torch.constant.int 1 %247 = torch.aten.transpose.int %246, %int0_361, %int1_362 : !torch.vtensor<[12288,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,12288],f16> %__auto.attn.img_mlp.0.bias = util.global.load 
@__auto.attn.img_mlp.0.bias : tensor<12288xf16> %248 = torch_c.from_builtin_tensor %__auto.attn.img_mlp.0.bias : tensor<12288xf16> -> !torch.vtensor<[12288],f16> %int6_363 = torch.constant.int 6 %249 = torch.prims.convert_element_type %248, %int6_363 : !torch.vtensor<[12288],f16>, !torch.int -> !torch.vtensor<[12288],f32> %int6_364 = torch.constant.int 6 %250 = torch.prims.convert_element_type %245, %int6_364 : !torch.vtensor<[4096,3072],f16>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int6_365 = torch.constant.int 6 %251 = torch.prims.convert_element_type %247, %int6_365 : !torch.vtensor<[3072,12288],f16>, !torch.int -> !torch.vtensor<[3072,12288],f32> %252 = torch.aten.mm %250, %251 : !torch.vtensor<[4096,3072],f32>, !torch.vtensor<[3072,12288],f32> -> !torch.vtensor<[4096,12288],f32> %int1_366 = torch.constant.int 1 %253 = torch.aten.mul.Scalar %252, %int1_366 : !torch.vtensor<[4096,12288],f32>, !torch.int -> !torch.vtensor<[4096,12288],f32> %int1_367 = torch.constant.int 1 %254 = torch.aten.mul.Scalar %249, %int1_367 : !torch.vtensor<[12288],f32>, !torch.int -> !torch.vtensor<[12288],f32> %int1_368 = torch.constant.int 1 %255 = torch.aten.add.Tensor %253, %254, %int1_368 : !torch.vtensor<[4096,12288],f32>, !torch.vtensor<[12288],f32>, !torch.int -> !torch.vtensor<[4096,12288],f32> %int5_369 = torch.constant.int 5 %256 = torch.prims.convert_element_type %255, %int5_369 : !torch.vtensor<[4096,12288],f32>, !torch.int -> !torch.vtensor<[4096,12288],f16> %int1_370 = torch.constant.int 1 %int4096_371 = torch.constant.int 4096 %int12288_372 = torch.constant.int 12288 %257 = torch.prim.ListConstruct %int1_370, %int4096_371, %int12288_372 : (!torch.int, !torch.int, !torch.int) -> !torch.list %258 = torch.aten.view %256, %257 : !torch.vtensor<[4096,12288],f16>, !torch.list -> !torch.vtensor<[1,4096,12288],f16> %str = torch.constant.str "tanh" %259 = torch.aten.gelu %258, %str : !torch.vtensor<[1,4096,12288],f16>, !torch.str -> !torch.vtensor<[1,4096,12288],f16> 
%int4096_373 = torch.constant.int 4096 %int12288_374 = torch.constant.int 12288 %260 = torch.prim.ListConstruct %int4096_373, %int12288_374 : (!torch.int, !torch.int) -> !torch.list %261 = torch.aten.view %259, %260 : !torch.vtensor<[1,4096,12288],f16>, !torch.list -> !torch.vtensor<[4096,12288],f16> %__auto.attn.img_mlp.2.weight = util.global.load @__auto.attn.img_mlp.2.weight : tensor<3072x12288xf16> %262 = torch_c.from_builtin_tensor %__auto.attn.img_mlp.2.weight : tensor<3072x12288xf16> -> !torch.vtensor<[3072,12288],f16> %int0_375 = torch.constant.int 0 %int1_376 = torch.constant.int 1 %263 = torch.aten.transpose.int %262, %int0_375, %int1_376 : !torch.vtensor<[3072,12288],f16>, !torch.int, !torch.int -> !torch.vtensor<[12288,3072],f16> %__auto.attn.img_mlp.2.bias = util.global.load @__auto.attn.img_mlp.2.bias : tensor<3072xf16> %264 = torch_c.from_builtin_tensor %__auto.attn.img_mlp.2.bias : tensor<3072xf16> -> !torch.vtensor<[3072],f16> %int6_377 = torch.constant.int 6 %265 = torch.prims.convert_element_type %264, %int6_377 : !torch.vtensor<[3072],f16>, !torch.int -> !torch.vtensor<[3072],f32> %int6_378 = torch.constant.int 6 %266 = torch.prims.convert_element_type %261, %int6_378 : !torch.vtensor<[4096,12288],f16>, !torch.int -> !torch.vtensor<[4096,12288],f32> %int6_379 = torch.constant.int 6 %267 = torch.prims.convert_element_type %263, %int6_379 : !torch.vtensor<[12288,3072],f16>, !torch.int -> !torch.vtensor<[12288,3072],f32> %268 = torch.aten.mm %266, %267 : !torch.vtensor<[4096,12288],f32>, !torch.vtensor<[12288,3072],f32> -> !torch.vtensor<[4096,3072],f32> %int1_380 = torch.constant.int 1 %269 = torch.aten.mul.Scalar %268, %int1_380 : !torch.vtensor<[4096,3072],f32>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int1_381 = torch.constant.int 1 %270 = torch.aten.mul.Scalar %265, %int1_381 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> %int1_382 = torch.constant.int 1 %271 = torch.aten.add.Tensor %269, %270, %int1_382 : 
!torch.vtensor<[4096,3072],f32>, !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[4096,3072],f32> %int5_383 = torch.constant.int 5 %272 = torch.prims.convert_element_type %271, %int5_383 : !torch.vtensor<[4096,3072],f32>, !torch.int -> !torch.vtensor<[4096,3072],f16> %int1_384 = torch.constant.int 1 %int4096_385 = torch.constant.int 4096 %int3072_386 = torch.constant.int 3072 %273 = torch.prim.ListConstruct %int1_384, %int4096_385, %int3072_386 : (!torch.int, !torch.int, !torch.int) -> !torch.list %274 = torch.aten.view %272, %273 : !torch.vtensor<[4096,3072],f16>, !torch.list -> !torch.vtensor<[1,4096,3072],f16> %275 = torch.aten.mul.Tensor %20, %274 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,4096,3072],f16> -> !torch.vtensor<[1,4096,3072],f16> %int1_387 = torch.constant.int 1 %276 = torch.aten.add.Tensor %233, %275, %int1_387 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,4096,3072],f16>, !torch.int -> !torch.vtensor<[1,4096,3072],f16> %int512_388 = torch.constant.int 512 %int3072_389 = torch.constant.int 3072 %277 = torch.prim.ListConstruct %int512_388, %int3072_389 : (!torch.int, !torch.int) -> !torch.list %278 = torch.aten.view %214, %277 : !torch.vtensor<[1,512,3072],f16>, !torch.list -> !torch.vtensor<[512,3072],f16> %__auto.attn.txt_attn.proj.weight = util.global.load @__auto.attn.txt_attn.proj.weight : tensor<3072x3072xf16> %279 = torch_c.from_builtin_tensor %__auto.attn.txt_attn.proj.weight : tensor<3072x3072xf16> -> !torch.vtensor<[3072,3072],f16> %int0_390 = torch.constant.int 0 %int1_391 = torch.constant.int 1 %280 = torch.aten.transpose.int %279, %int0_390, %int1_391 : !torch.vtensor<[3072,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,3072],f16> %__auto.attn.txt_attn.proj.bias = util.global.load @__auto.attn.txt_attn.proj.bias : tensor<3072xf16> %281 = torch_c.from_builtin_tensor %__auto.attn.txt_attn.proj.bias : tensor<3072xf16> -> !torch.vtensor<[3072],f16> %int6_392 = torch.constant.int 6 %282 = 
torch.prims.convert_element_type %281, %int6_392 : !torch.vtensor<[3072],f16>, !torch.int -> !torch.vtensor<[3072],f32> %int6_393 = torch.constant.int 6 %283 = torch.prims.convert_element_type %278, %int6_393 : !torch.vtensor<[512,3072],f16>, !torch.int -> !torch.vtensor<[512,3072],f32> %int6_394 = torch.constant.int 6 %284 = torch.prims.convert_element_type %280, %int6_394 : !torch.vtensor<[3072,3072],f16>, !torch.int -> !torch.vtensor<[3072,3072],f32> %285 = torch.aten.mm %283, %284 : !torch.vtensor<[512,3072],f32>, !torch.vtensor<[3072,3072],f32> -> !torch.vtensor<[512,3072],f32> %int1_395 = torch.constant.int 1 %286 = torch.aten.mul.Scalar %285, %int1_395 : !torch.vtensor<[512,3072],f32>, !torch.int -> !torch.vtensor<[512,3072],f32> %int1_396 = torch.constant.int 1 %287 = torch.aten.mul.Scalar %282, %int1_396 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32> %int1_397 = torch.constant.int 1 %288 = torch.aten.add.Tensor %286, %287, %int1_397 : !torch.vtensor<[512,3072],f32>, !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[512,3072],f32> %int5_398 = torch.constant.int 5 %289 = torch.prims.convert_element_type %288, %int5_398 : !torch.vtensor<[512,3072],f32>, !torch.int -> !torch.vtensor<[512,3072],f16> %int1_399 = torch.constant.int 1 %int512_400 = torch.constant.int 512 %int3072_401 = torch.constant.int 3072 %290 = torch.prim.ListConstruct %int1_399, %int512_400, %int3072_401 : (!torch.int, !torch.int, !torch.int) -> !torch.list %291 = torch.aten.view %289, %290 : !torch.vtensor<[512,3072],f16>, !torch.list -> !torch.vtensor<[1,512,3072],f16> %292 = torch.aten.mul.Tensor %38, %291 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,512,3072],f16> -> !torch.vtensor<[1,512,3072],f16> %int1_402 = torch.constant.int 1 %293 = torch.aten.add.Tensor %arg1, %292, %int1_402 : !torch.vtensor<[1,512,3072],f16>, !torch.vtensor<[1,512,3072],f16>, !torch.int -> !torch.vtensor<[1,512,3072],f16> %int1_403 = torch.constant.int 1 %int1_404 = 
torch.constant.int 1
    // ---- Tail of func.func @main: txt-stream MLP sub-block (txt_mlp.* weights), then the return. ----
    // Values defined earlier in @main (outside this chunk), typed as annotated at their uses below:
    //   %293 : [1,512,3072] f16  - layer-normalized below, and also the residual added back in %336
    //   %40  : [1,1,3072]   f16  - enters as (1 + %40), multiplying the normalized activations
    //   %39  : [1,1,3072]   f16  - added after that multiply (shift-like role)
    //   %41  : [1,1,3072]   f16  - multiplies the MLP output before the residual add (gate-like role)
    //   %276 : [1,4096,3072] f16 - returned unchanged as the first result
    // NOTE(review): presumably %39/%40/%41 are chunks of the txt modulation vector computed above - confirm.
    // %294 = %40 + %int1_404 * %int1_403. %int1_404 is the constant 1 just above; %int1_403 is defined
    // before this chunk (presumably also 1 - confirm).
    %294 = torch.aten.add.Scalar %40, %int1_403, %int1_404 : !torch.vtensor<[1,1,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3072],f16>
    // LayerNorm of %293 over dim 2 (the last, size-3072 dim), computed in f32 (dtype code 6 -> f32 per the
    // result type). var_mean uses correction = 0 (biased variance) and keepdim = true; eps = 1e-6.
    // No affine weight/bias is applied.
    %int6_405 = torch.constant.int 6
    %295 = torch.prims.convert_element_type %293, %int6_405 : !torch.vtensor<[1,512,3072],f16>, !torch.int -> !torch.vtensor<[1,512,3072],f32>
    %int2_406 = torch.constant.int 2
    %296 = torch.prim.ListConstruct %int2_406 : (!torch.int) -> !torch.list
    %int0_407 = torch.constant.int 0
    %true_408 = torch.constant.bool true
    %result0_409, %result1_410 = torch.aten.var_mean.correction %295, %296, %int0_407, %true_408 : !torch.vtensor<[1,512,3072],f32>, !torch.list, !torch.int, !torch.bool -> !torch.vtensor<[1,512,1],f32>, !torch.vtensor<[1,512,1],f32>
    // rstd = rsqrt(var + eps)
    %float9.999990e-07_411 = torch.constant.float 9.9999999999999995E-7
    %int1_412 = torch.constant.int 1
    %297 = torch.aten.add.Scalar %result0_409, %float9.999990e-07_411, %int1_412 : !torch.vtensor<[1,512,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,512,1],f32>
    %298 = torch.aten.rsqrt %297 : !torch.vtensor<[1,512,1],f32> -> !torch.vtensor<[1,512,1],f32>
    // (x - mean) * rstd.
    // NOTE(review): the subtraction reads the f16 %293 rather than its f32 copy %295. The result is typed
    // f32, so %293 is promoted to f32, which is exact for f16. The values therefore match using %295.
    %int1_413 = torch.constant.int 1
    %299 = torch.aten.sub.Tensor %293, %result1_410, %int1_413 : !torch.vtensor<[1,512,3072],f16>, !torch.vtensor<[1,512,1],f32>, !torch.int -> !torch.vtensor<[1,512,3072],f32>
    %300 = torch.aten.mul.Tensor %299, %298 : !torch.vtensor<[1,512,3072],f32>, !torch.vtensor<[1,512,1],f32> -> !torch.vtensor<[1,512,3072],f32>
    // Back to f16 (dtype code 5), then modulate: %303 = (1 + %40) * LN(%293) + %39, broadcast over dim 1.
    %int5_414 = torch.constant.int 5
    %301 = torch.prims.convert_element_type %300, %int5_414 : !torch.vtensor<[1,512,3072],f32>, !torch.int -> !torch.vtensor<[1,512,3072],f16>
    %302 = torch.aten.mul.Tensor %294, %301 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,512,3072],f16> -> !torch.vtensor<[1,512,3072],f16>
    %int1_415 = torch.constant.int 1
    %303 = torch.aten.add.Tensor %302, %39, %int1_415 : !torch.vtensor<[1,512,3072],f16>, !torch.vtensor<[1,1,3072],f16>, !torch.int -> !torch.vtensor<[1,512,3072],f16>
    // txt_mlp.0: Linear 3072 -> 12288 on the flattened [512,3072] activations, computed in f32.
    //   %315 = %310 @ transpose(txt_mlp.0.weight [12288,3072]) + txt_mlp.0.bias [12288]
    // The two mul.Scalar by 1 are no-op scale factors left by the lowering (presumably addmm's alpha/beta).
    %int512_416 = torch.constant.int 512
    %int3072_417 = torch.constant.int 3072
    %304 = torch.prim.ListConstruct %int512_416, %int3072_417 : (!torch.int, !torch.int) -> !torch.list
    %305 = torch.aten.view %303, %304 : !torch.vtensor<[1,512,3072],f16>, !torch.list -> !torch.vtensor<[512,3072],f16>
    %__auto.attn.txt_mlp.0.weight = util.global.load @__auto.attn.txt_mlp.0.weight : tensor<12288x3072xf16>
    %306 = torch_c.from_builtin_tensor %__auto.attn.txt_mlp.0.weight : tensor<12288x3072xf16> -> !torch.vtensor<[12288,3072],f16>
    %int0_418 = torch.constant.int 0
    %int1_419 = torch.constant.int 1
    %307 = torch.aten.transpose.int %306, %int0_418, %int1_419 : !torch.vtensor<[12288,3072],f16>, !torch.int, !torch.int -> !torch.vtensor<[3072,12288],f16>
    %__auto.attn.txt_mlp.0.bias = util.global.load @__auto.attn.txt_mlp.0.bias : tensor<12288xf16>
    %308 = torch_c.from_builtin_tensor %__auto.attn.txt_mlp.0.bias : tensor<12288xf16> -> !torch.vtensor<[12288],f16>
    %int6_420 = torch.constant.int 6
    %309 = torch.prims.convert_element_type %308, %int6_420 : !torch.vtensor<[12288],f16>, !torch.int -> !torch.vtensor<[12288],f32>
    %int6_421 = torch.constant.int 6
    %310 = torch.prims.convert_element_type %305, %int6_421 : !torch.vtensor<[512,3072],f16>, !torch.int -> !torch.vtensor<[512,3072],f32>
    %int6_422 = torch.constant.int 6
    %311 = torch.prims.convert_element_type %307, %int6_422 : !torch.vtensor<[3072,12288],f16>, !torch.int -> !torch.vtensor<[3072,12288],f32>
    %312 = torch.aten.mm %310, %311 : !torch.vtensor<[512,3072],f32>, !torch.vtensor<[3072,12288],f32> -> !torch.vtensor<[512,12288],f32>
    %int1_423 = torch.constant.int 1
    %313 = torch.aten.mul.Scalar %312, %int1_423 : !torch.vtensor<[512,12288],f32>, !torch.int -> !torch.vtensor<[512,12288],f32>
    %int1_424 = torch.constant.int 1
    %314 = torch.aten.mul.Scalar %309, %int1_424 : !torch.vtensor<[12288],f32>, !torch.int -> !torch.vtensor<[12288],f32>
    %int1_425 = torch.constant.int 1
    %315 = torch.aten.add.Tensor %313, %314, %int1_425 : !torch.vtensor<[512,12288],f32>, !torch.vtensor<[12288],f32>, !torch.int -> !torch.vtensor<[512,12288],f32>
    // Back to f16, reshape to [1,512,12288], apply GELU with the "tanh" approximation, flatten again.
    %int5_426 = torch.constant.int 5
    %316 = torch.prims.convert_element_type %315, %int5_426 : !torch.vtensor<[512,12288],f32>, !torch.int -> !torch.vtensor<[512,12288],f16>
    %int1_427 = torch.constant.int 1
    %int512_428 = torch.constant.int 512
    %int12288_429 = torch.constant.int 12288
    %317 = torch.prim.ListConstruct %int1_427, %int512_428, %int12288_429 : (!torch.int, !torch.int, !torch.int) -> !torch.list
    %318 = torch.aten.view %316, %317 : !torch.vtensor<[512,12288],f16>, !torch.list -> !torch.vtensor<[1,512,12288],f16>
    %str_430 = torch.constant.str "tanh"
    %319 = torch.aten.gelu %318, %str_430 : !torch.vtensor<[1,512,12288],f16>, !torch.str -> !torch.vtensor<[1,512,12288],f16>
    %int512_431 = torch.constant.int 512
    %int12288_432 = torch.constant.int 12288
    %320 = torch.prim.ListConstruct %int512_431, %int12288_432 : (!torch.int, !torch.int) -> !torch.list
    %321 = torch.aten.view %319, %320 : !torch.vtensor<[1,512,12288],f16>, !torch.list -> !torch.vtensor<[512,12288],f16>
    // txt_mlp.2: Linear 12288 -> 3072, same f32 pattern as txt_mlp.0.
    //   %331 = %326 @ transpose(txt_mlp.2.weight [3072,12288]) + txt_mlp.2.bias [3072]
    %__auto.attn.txt_mlp.2.weight = util.global.load @__auto.attn.txt_mlp.2.weight : tensor<3072x12288xf16>
    %322 = torch_c.from_builtin_tensor %__auto.attn.txt_mlp.2.weight : tensor<3072x12288xf16> -> !torch.vtensor<[3072,12288],f16>
    %int0_433 = torch.constant.int 0
    %int1_434 = torch.constant.int 1
    %323 = torch.aten.transpose.int %322, %int0_433, %int1_434 : !torch.vtensor<[3072,12288],f16>, !torch.int, !torch.int -> !torch.vtensor<[12288,3072],f16>
    %__auto.attn.txt_mlp.2.bias = util.global.load @__auto.attn.txt_mlp.2.bias : tensor<3072xf16>
    %324 = torch_c.from_builtin_tensor %__auto.attn.txt_mlp.2.bias : tensor<3072xf16> -> !torch.vtensor<[3072],f16>
    %int6_435 = torch.constant.int 6
    %325 = torch.prims.convert_element_type %324, %int6_435 : !torch.vtensor<[3072],f16>, !torch.int -> !torch.vtensor<[3072],f32>
    %int6_436 = torch.constant.int 6
    %326 = torch.prims.convert_element_type %321, %int6_436 : !torch.vtensor<[512,12288],f16>, !torch.int -> !torch.vtensor<[512,12288],f32>
    %int6_437 = torch.constant.int 6
    %327 = torch.prims.convert_element_type %323, %int6_437 : !torch.vtensor<[12288,3072],f16>, !torch.int -> !torch.vtensor<[12288,3072],f32>
    %328 = torch.aten.mm %326, %327 : !torch.vtensor<[512,12288],f32>, !torch.vtensor<[12288,3072],f32> -> !torch.vtensor<[512,3072],f32>
    %int1_438 = torch.constant.int 1
    %329 = torch.aten.mul.Scalar %328, %int1_438 : !torch.vtensor<[512,3072],f32>, !torch.int -> !torch.vtensor<[512,3072],f32>
    %int1_439 = torch.constant.int 1
    %330 = torch.aten.mul.Scalar %325, %int1_439 : !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[3072],f32>
    %int1_440 = torch.constant.int 1
    %331 = torch.aten.add.Tensor %329, %330, %int1_440 : !torch.vtensor<[512,3072],f32>, !torch.vtensor<[3072],f32>, !torch.int -> !torch.vtensor<[512,3072],f32>
    // Back to f16 and reshape to [1,512,3072].
    %int5_441 = torch.constant.int 5
    %332 = torch.prims.convert_element_type %331, %int5_441 : !torch.vtensor<[512,3072],f32>, !torch.int -> !torch.vtensor<[512,3072],f16>
    %int1_442 = torch.constant.int 1
    %int512_443 = torch.constant.int 512
    %int3072_444 = torch.constant.int 3072
    %333 = torch.prim.ListConstruct %int1_442, %int512_443, %int3072_444 : (!torch.int, !torch.int, !torch.int) -> !torch.list
    %334 = torch.aten.view %332, %333 : !torch.vtensor<[512,3072],f16>, !torch.list -> !torch.vtensor<[1,512,3072],f16>
    // Gated residual: %336 = %293 + %41 * mlp_out (%41 broadcast over dim 1).
    %335 = torch.aten.mul.Tensor %41, %334 : !torch.vtensor<[1,1,3072],f16>, !torch.vtensor<[1,512,3072],f16> -> !torch.vtensor<[1,512,3072],f16>
    %int1_445 = torch.constant.int 1
    %336 = torch.aten.add.Tensor %293, %335, %int1_445 : !torch.vtensor<[1,512,3072],f16>, !torch.vtensor<[1,512,3072],f16>, !torch.int -> !torch.vtensor<[1,512,3072],f16>
    // Results: %276 ([1,4096,3072], computed above this chunk) and the updated txt tensor %336 ([1,512,3072]).
    return %276, %336 : !torch.vtensor<[1,4096,3072],f16>, !torch.vtensor<[1,512,3072],f16>
  } // end func.func @main
} // end module @compiled_attn