Gemma3ForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class
#57 · opened by buckeye17-bah
I'm trying to run the reproduction code below, taken straight from the Gemma3 Hugging Face documentation here. It produces a warning message and an error message, both of which I've also shared below. I've isolated the error to the cache_implementation="static" argument of the model.generate() call; in other words, dropping that argument lets the model run fine.
Based on the error message, it seems that this cache implementation requires the triton Python package, which can only be installed on Linux platforms. I'm planning to test this in a Docker container soon to clarify whether the missing package is the true culprit. I'll report back then.
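As a quick check on that theory, a bare import test shows whether triton is even present in this environment (just a diagnostic on my side, separate from the reproduction code below):

import importlib.util

# Look up the triton package without assuming it is installed.
spec = importlib.util.find_spec("triton")
if spec is None:
    print("triton is not installed in this environment")
else:
    import triton
    print(f"found triton {triton.__version__} at {spec.origin}")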
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",
    padding_side="left"
)

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to("cuda")

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
Warning message:
Gemma3ForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're writing code, see Llama for an example implementation. If you're a user, please report this issue on GitHub.
Error message:
BackendCompilerFailed Traceback (most recent call last)
Cell In[1], line 37
15 messages = [
16 {
17 "role": "system",
(...) 27 },
28 ]
29 inputs = processor.apply_chat_template(
30 messages,
31 tokenize=True,
(...) 34 add_generation_prompt=True,
35 ).to("cuda")
---> 37 output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
38 print(processor.decode(output[0], skip_special_tokens=True))
File [my venv root dir]\gemma3\Lib\site-packages\torch\utils\_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File [my venv root dir]\gemma3\Lib\site-packages\transformers\generation\utils.py:2465, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)
2457 input_ids, model_kwargs = self._expand_inputs_for_generation(
2458 input_ids=input_ids,
2459 expand_size=generation_config.num_return_sequences,
2460 is_encoder_decoder=self.config.is_encoder_decoder,
2461 **model_kwargs,
2462 )
2464 # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2465 result = self._sample(
2466 input_ids,
2467 logits_processor=prepared_logits_processor,
2468 stopping_criteria=prepared_stopping_criteria,
2469 generation_config=generation_config,
2470 synced_gpus=synced_gpus,
2471 streamer=streamer,
2472 **model_kwargs,
2473 )
2475 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2476 # 11. interleave input_ids with `num_beams` additional sequences per batch
2477 input_ids, model_kwargs = self._expand_inputs_for_generation(
2478 input_ids=input_ids,
2479 expand_size=generation_config.num_beams,
2480 is_encoder_decoder=self.config.is_encoder_decoder,
2481 **model_kwargs,
2482 )
File [my venv root dir]\gemma3\Lib\site-packages\transformers\generation\utils.py:3434, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
3432 is_prefill = False
3433 else:
-> 3434 outputs = model_forward(**model_inputs, return_dict=True)
3436 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
3437 model_kwargs = self._update_model_kwargs_for_generation(
3438 outputs,
3439 model_kwargs,
3440 is_encoder_decoder=self.config.is_encoder_decoder,
3441 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\eval_frame.py:574, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
569 saved_dynamic_layer_stack_depth = (
570 torch._C._functorch.get_dynamic_layer_stack_depth()
571 )
573 try:
--> 574 return fn(*args, **kwargs)
575 finally:
576 # Restore the dynamic layer stack depth if necessary.
577 torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
578 saved_dynamic_layer_stack_depth
579 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\nn\modules\module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File [my venv root dir]\gemma3\Lib\site-packages\torch\nn\modules\module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:1380, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
1374 return hijacked_callback(
1375 frame, cache_entry, self.hooks, frame_state
1376 )
1378 with compile_lock, _disable_current_modes():
1379 # skip=1: skip this frame
-> 1380 return self._torchdynamo_orig_callable(
1381 frame, cache_entry, self.hooks, frame_state, skip=1
1382 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:547, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
544 dynamo_tls.traced_frame_infos.append(info)
546 with compile_context(CompileContext(compile_id)):
--> 547 return _compile(
548 frame.f_code,
549 frame.f_globals,
550 frame.f_locals,
551 frame.f_builtins,
552 frame.closure,
553 self._torchdynamo_orig_callable,
554 self._one_graph,
555 self._export,
556 self._export_constraints,
557 hooks,
558 cache_entry,
559 cache_size,
560 frame,
561 frame_state=frame_state,
562 compile_id=compile_id,
563 skip=skip + 1,
564 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:986, in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
984 guarded_code = None
985 try:
--> 986 guarded_code = compile_inner(code, one_graph, hooks, transform)
988 # NB: We only put_code_state in success case. Success case here
989 # does include graph breaks; specifically, if a graph break still
990 # resulted in a partially compiled graph, we WILL return here. An
(...) 995 # to upload for graph break though, because this can prevent
996 # extra graph break compilations.)
997 put_code_state()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:715, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
713 stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
714 stack.enter_context(CompileTimeInstructionCounter.record())
--> 715 return _compile_inner(code, one_graph, hooks, transform)
717 return None
File [my venv root dir]\gemma3\Lib\site-packages\torch\_utils_internal.py:95, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
92 kwargs["skip"] = skip + 1
94 if not StrobelightCompileTimeProfiler.enabled:
---> 95 return function(*args, **kwargs)
97 return StrobelightCompileTimeProfiler.profile_compile_time(
98 function, phase_name, *args, **kwargs
99 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:750, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
748 CompileContext.get().attempt = attempt
749 try:
--> 750 out_code = transform_code_object(code, transform)
751 break
752 except exc.RestartAnalysis as e:
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\bytecode_transformation.py:1361, in transform_code_object(code, transformations, safe)
1358 instructions = cleaned_instructions(code, safe)
1359 propagate_line_nums(instructions)
-> 1361 transformations(instructions, code_options)
1362 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:231, in preserve_global_state.<locals>._fn(*args, **kwargs)
229 exit_stack.enter_context(torch_function_mode_stack_state_mgr)
230 try:
--> 231 return fn(*args, **kwargs)
232 finally:
233 cleanup.close()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:662, in _compile.<locals>.transform(instructions, code_options)
660 try:
661 with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 662 tracer.run()
663 except exc.UnspecializeRestartAnalysis:
664 speculation_log.clear()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:2868, in InstructionTranslator.run(self)
2867 def run(self):
-> 2868 super().run()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
1050 try:
1051 self.output.push_tx(self)
-> 1052 while self.step():
1053 pass
1054 except TensorifyScalarRestartAnalysis:
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
959 self.update_block_stack(inst)
961 try:
--> 962 self.dispatch_table[inst.opcode](self, inst)
963 return not self.output.should_exit
964 except TensorifyScalarRestartAnalysis:
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:3048, in InstructionTranslator.RETURN_VALUE(self, inst)
3047 def RETURN_VALUE(self, inst):
-> 3048 self._return(inst)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:3033, in InstructionTranslator._return(self, inst)
3028 _step_logger()(
3029 logging.INFO,
3030 f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
3031 )
3032 log.debug("%s triggered compile", inst.opname)
-> 3033 self.output.compile_subgraph(
3034 self,
3035 reason=GraphCompileReason(
3036 "return_value", [self.frame_summary()], graph_break=False
3037 ),
3038 )
3039 return_inst = (
3040 create_instruction("RETURN_VALUE")
3041 if inst.opname == "RETURN_VALUE"
3042 else create_instruction("RETURN_CONST", argval=inst.argval)
3043 )
3044 self.output.add_output_instructions([return_inst])
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1136, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
1133 output = []
1134 if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
1135 output.extend(
-> 1136 self.compile_and_call_fx_graph(
1137 tx, pass2.graph_output_vars(), root, output_replacements
1138 )
1139 )
1141 if len(pass2.graph_outputs) != 0:
1142 output.append(pass2.create_store(graph_output_var))
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1382, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs)
1379 self.tracing_context.fake_mode = backend_fake_mode
1381 with self.restore_global_state():
-> 1382 compiled_fn = self.call_user_compiler(gm)
1384 from torch.fx._lazy_graph_module import _LazyGraphModule
1386 if isinstance(compiled_fn, _LazyGraphModule) or (
1387 isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
1388 and compiled_fn.__name__ == "_lazy_forward" # type: ignore[attr-defined]
(...) 1392 # this is a _LazyGraphModule. This makes it easier for dynamo to
1393 # optimize a _LazyGraphModule.
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1432, in OutputGraph.call_user_compiler(self, gm)
1425 def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
1426 with dynamo_timed(
1427 "OutputGraph.call_user_compiler",
1428 phase_name="backend_compile",
1429 log_pt2_compile_event=True,
1430 dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
1431 ):
-> 1432 return self._call_user_compiler(gm)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1483, in OutputGraph._call_user_compiler(self, gm)
1481 raise e
1482 except Exception as e:
-> 1483 raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
1484 e.__traceback__
1485 ) from None
1487 signpost_event(
1488 "dynamo",
1489 "OutputGraph.call_user_compiler",
(...) 1495 },
1496 )
1498 return compiled_fn
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1462, in OutputGraph._call_user_compiler(self, gm)
1460 if config.verify_correctness:
1461 compiler_fn = WrapperBackend(compiler_fn)
-> 1462 compiled_fn = compiler_fn(gm, self.example_inputs())
1463 _step_logger()(logging.INFO, f"done compiler function {name}")
1464 assert callable(compiled_fn), "compiler_fn did not return callable"
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py:130, in WrapBackendDebug.__call__(self, gm, example_inputs, **kwargs)
128 raise
129 else:
--> 130 compiled_gm = compiler_fn(gm, example_inputs)
132 return compiled_gm
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py:130, in WrapBackendDebug.__call__(self, gm, example_inputs, **kwargs)
128 raise
129 else:
--> 130 compiled_gm = compiler_fn(gm, example_inputs)
132 return compiled_gm
File [my venv root dir]\gemma3\Lib\site-packages\torch\__init__.py:2340, in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
2337 def __call__(self, model_, inputs_):
2338 from torch._inductor.compile_fx import compile_fx
-> 2340 return compile_fx(model_, inputs_, config_patches=self.config)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1552, in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions)
1550 if config_patches:
1551 with config.patch(config_patches):
-> 1552 return compile_fx(
1553 model_,
1554 example_inputs_,
1555 # need extra layer of patching as backwards is compiled out of scope
1556 inner_compile=config.patch(config_patches)(inner_compile),
1557 decompositions=decompositions,
1558 )
1560 # TODO: This probably shouldn't be a recursive call
1561 if config.cpp_wrapper:
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1863, in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions)
1856 return inference_compiler(unlifted_gm, example_inputs_)
1858 with V.set_fake_mode(fake_mode), torch._guards.tracing(
1859 tracing_context
1860 ), compiled_autograd._disable(), functorch_config.patch(
1861 unlift_effect_tokens=True
1862 ):
-> 1863 return aot_autograd(
1864 fw_compiler=fw_compiler,
1865 bw_compiler=bw_compiler,
1866 inference_compiler=inference_compiler,
1867 decompositions=decompositions,
1868 partition_fn=partition_fn,
1869 keep_inference_input_mutations=True,
1870 cudagraphs=cudagraphs,
1871 )(model_, example_inputs_)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\backends\common.py:83, in AotAutograd.__call__(self, gm, example_inputs, **kwargs)
80 try:
81 # NB: NOT cloned!
82 with enable_aot_logging(), patch_config:
---> 83 cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
84 counters["aot_autograd"]["ok"] += 1
85 return disable(cg)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:1155, in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler, cudagraphs)
1145 compiled_fn = AOTAutogradCache.load(
1146 dispatch_and_compile,
1147 mod,
(...) 1152 remote,
1153 )
1154 else:
-> 1155 compiled_fn = dispatch_and_compile()
1157 if isinstance(mod, torch._dynamo.utils.GmWrapper):
1158 # This function is called by the flatten_graph_inputs wrapper, which boxes
1159 # the inputs so that they can be freed before the end of this scope.
1160 # For overhead reasons, this is not the default wrapper, see comment:
1161 # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
1162 def boxed_forward(runtime_args: List[Any]):
File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:1131, in aot_module_simplified.<locals>.dispatch_and_compile()
1129 functional_call = create_functional_call(mod, params_spec, params_len)
1130 with compiled_autograd._disable():
-> 1131 compiled_fn, _ = create_aot_dispatcher_function(
1132 functional_call,
1133 fake_flat_args,
1134 aot_config,
1135 fake_mode,
1136 shape_env,
1137 )
1138 return compiled_fn
File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:580, in create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)
572 def create_aot_dispatcher_function(
573 flat_fn,
574 fake_flat_args: FakifiedFlatArgs,
(...) 577 shape_env: Optional[ShapeEnv],
578 ) -> Tuple[Callable, ViewAndMutationMeta]:
579 with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True):
--> 580 return _create_aot_dispatcher_function(
581 flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
582 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:830, in _create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)
826 return aot_dispatch_base
828 compiler_fn = choose_dispatcher(needs_autograd, aot_config)
--> 830 compiled_fn, fw_metadata = compiler_fn(
831 flat_fn,
832 _dup_fake_script_obj(fake_flat_args),
833 aot_config,
834 fw_metadata=fw_metadata,
835 )
836 return compiled_fn, fw_metadata
File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\_aot_autograd\jit_compile_runtime_wrappers.py:203, in aot_dispatch_base(flat_fn, flat_args, aot_config, fw_metadata)
201 assert isinstance(fw_module, GraphModule)
202 tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode)
--> 203 compiled_fw = compiler(fw_module, updated_flat_args)
205 if fakified_out_wrapper.needs_post_compile:
206 fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:489, in SerializableAOTDispatchCompiler.__call__(self, gm, example_inputs)
484 def __call__(
485 self,
486 gm: torch.fx.GraphModule,
487 example_inputs: Sequence[InputType],
488 ) -> OutputCode:
--> 489 return self.compiler_fn(gm, example_inputs)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1741, in compile_fx.<locals>.fw_compiler_base(gm, example_inputs, is_inference)
1738 else:
1739 model_outputs_node.meta["user_visible_output_idxs"] = []
-> 1741 return inner_compile(
1742 gm,
1743 example_inputs,
1744 static_input_idxs=get_static_input_idxs(fixed),
1745 cudagraphs=cudagraphs,
1746 graph_id=graph_id,
1747 is_inference=is_inference,
1748 boxed_forward_device_index=forward_device,
1749 )
File C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.12_3.12.2800.0_x64__qbz5n2kfra8p0\Lib\contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
80 with self._recreate_cm():
---> 81 return func(*args, **kwds)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:569, in compile_fx_inner(gm, example_inputs, **kwargs)
562 stack.enter_context(DebugContext())
564 get_chromium_event_logger().add_event_data(
565 "inductor_compile",
566 is_backward=kwargs["is_backward"],
567 )
--> 569 return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
570 gm,
571 example_inputs,
572 **kwargs,
573 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\repro\after_aot.py:102, in wrap_compiler_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
97 assert config.repro_after in ("dynamo", "aot", None)
99 try:
100 # Call the compiler_fn - which is either aot_autograd or inductor
101 # with fake inputs
--> 102 inner_compiled_fn = compiler_fn(gm, example_inputs)
103 except Exception as e:
104 # TODO: Failures here are troublesome because no real inputs,
105 # need a different serialization strategy
106 if config.repro_after == "aot":
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:685, in _compile_fx_inner(gm, example_inputs, **graph_kwargs)
683 TritonBundler.begin_compile()
684 try:
--> 685 mb_compiled_graph = fx_codegen_and_compile(
686 gm, example_inputs, inputs_to_check, **graph_kwargs
687 )
688 assert mb_compiled_graph is not None
689 mb_compiled_graph._time_taken_ns = time.time_ns() - start_time
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1129, in fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)
1119 def fx_codegen_and_compile(
1120 gm: GraphModule,
1121 example_inputs: Sequence[InputType],
(...) 1125 **graph_kwargs: Unpack[_CompileFxKwargs],
1126 ) -> OutputCode:
1127 scheme: FxCompile = _InProcessFxCompile()
-> 1129 return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1044, in _InProcessFxCompile.codegen_and_compile(self, gm, example_inputs, inputs_to_check, graph_kwargs)
1036 compiled_fn = AotCodeCompiler.compile(
1037 graph,
1038 code,
(...) 1041 additional_files=additional_files,
1042 )
1043 else:
-> 1044 compiled_fn = graph.compile_to_module().call
1046 num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
1047 metrics.num_bytes_accessed += num_bytes
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\graph.py:2027, in GraphLowering.compile_to_module(self)
2020 def compile_to_module(self) -> ModuleType:
2021 with dynamo_timed(
2022 "GraphLowering.compile_to_module",
2023 phase_name="code_gen",
2024 log_pt2_compile_event=True,
2025 dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
2026 ):
-> 2027 return self._compile_to_module()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\graph.py:2033, in GraphLowering._compile_to_module(self)
2029 def _compile_to_module(self) -> ModuleType:
2030 from .codecache import PyCodeCache
2032 code, linemap = (
-> 2033 self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
2034 )
2035 if config.triton.autotune_at_compile_time:
2036 tuning_code = (
2037 '"""\n'
2038 + "Compile-time auto-tuning block: \n"
(...) 2041 + '"""\n'
2042 )
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\graph.py:1964, in GraphLowering.codegen(self)
1960 from .scheduler import Scheduler
1962 self.init_wrapper_code()
-> 1964 self.scheduler = Scheduler(self.operations)
1965 V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
1967 self.wrapper_code.push_codegened_graph(self)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:1798, in Scheduler.__init__(self, nodes)
1796 def __init__(self, nodes: List[ir.Operation]) -> None:
1797 with dynamo_timed("Scheduler.__init__"):
-> 1798 self._init(nodes)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:1816, in Scheduler._init(self, nodes)
1807 self.completed_operations: OrderedSet[str] = OrderedSet()
1808 self.available_buffer_names = OrderedSet(
1809 [
1810 *V.graph.graph_inputs.keys(),
(...) 1813 ]
1814 )
-> 1816 self.nodes = [self.create_scheduler_node(n) for n in nodes]
1817 self.update_zero_dim_cpu_tensor()
1818 # some new constants could have been created above
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:1947, in Scheduler.create_scheduler_node(self, node)
1945 return NopKernelSchedulerNode(self, node)
1946 elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
-> 1947 return SchedulerNode(self, node)
1948 elif isinstance(node, ir.ExternKernel):
1949 return ExternKernelSchedulerNode(self, node)
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:893, in SchedulerNode.__init__(self, scheduler, node)
891 super().__init__(scheduler)
892 self._init_from_node(node)
--> 893 self._compute_attrs()
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:907, in SchedulerNode._compute_attrs(self, extra_indexing_constraints, recompute_sizes_body_func)
901 self._sizes, self._body = self.node.simplify_and_reorder(
902 extra_indexing_constraints=extra_indexing_constraints,
903 recompute_sizes_body_func=recompute_sizes_body_func,
904 )
906 device = self.node.get_device_or_error()
--> 907 group_fn = self.scheduler.get_backend(device).group_fn
908 self.group = (device, group_fn(self._sizes))
910 # Don't normalize since normalization will merge loops which
911 # makes it hard to decide new loop orders.
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:3441, in Scheduler.get_backend(self, device)
3439 assert device is not None
3440 if device not in self.backends:
-> 3441 self.backends[device] = self.create_backend(device)
3442 return self.backends[device]
File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:3432, in Scheduler.create_backend(self, device)
3428 raise RuntimeError(
3429 f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950
3430 )
3431 elif is_gpu(device.type):
-> 3432 raise RuntimeError(
3433 "Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950
3434 )
3436 return device_scheduling(self)
BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
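For reference, the two workarounds I see until triton is available on this platform (a minimal sketch reusing the model, processor and inputs from the reproduction code above; I've only verified the first one):

# Option 1: drop the static cache so no compiled forward pass is requested.
output = model.generate(**inputs, max_new_tokens=50)

# Option 2: keep the static cache but suppress the compile error and fall back
# to eager, as suggested at the end of the traceback.
import torch._dynamo
torch._dynamo.config.suppress_errors = True
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")

print(processor.decode(output[0], skip_special_tokens=True))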
Python environment:
Package Version
------------------------- --------------
accelerate 1.6.0
anyio 4.9.0
argon2-cffi 23.1.0
argon2-cffi-bindings 21.2.0
arrow 1.3.0
asttokens 3.0.0
async-lru 2.0.5
attrs 25.3.0
babel 2.17.0
beautifulsoup4 4.13.4
bleach 6.2.0
certifi 2025.1.31
cffi 1.17.1
charset-normalizer 3.4.1
colorama 0.4.6
comm 0.2.2
debugpy 1.8.14
decorator 5.2.1
defusedxml 0.7.1
executing 2.2.0
fastjsonschema 2.21.1
filelock 3.13.1
fqdn 1.5.1
fsspec 2024.6.1
h11 0.14.0
httpcore 1.0.8
httpx 0.28.1
huggingface-hub 0.30.2
idna 3.10
ipykernel 6.29.5
ipython 9.1.0
ipython_pygments_lexers 1.1.1
ipywidgets 8.1.6
isoduration 20.11.0
jedi 0.19.2
Jinja2 3.1.4
json5 0.12.0
jsonpointer 3.0.0
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter_client 8.6.3
jupyter_core 5.7.2
jupyter-events 0.12.0
jupyter-lsp 2.2.5
jupyter_server 2.15.0
jupyter_server_terminals 0.5.3
jupyterlab 4.4.0
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
jupyterlab_widgets 3.0.14
kaleido 0.1.0.post1
MarkupSafe 2.1.5
matplotlib-inline 0.1.7
mistune 3.1.3
mpmath 1.3.0
narwhals 1.35.0
nbclient 0.10.2
nbconvert 7.16.6
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.3
notebook 7.4.0
notebook_shim 0.2.4
numpy 2.1.2
overrides 7.7.0
packaging 24.2
pandas 2.2.3
pandocfilters 1.5.1
parso 0.8.4
pillow 11.0.0
pip 25.0.1
platformdirs 4.3.7
plotly 6.0.1
prometheus_client 0.21.1
prompt_toolkit 3.0.51
psutil 7.0.0
pure_eval 0.2.3
pycparser 2.22
Pygments 2.19.1
python-dateutil 2.9.0.post0
python-json-logger 3.3.0
pytz 2025.2
pywin32 310
pywinpty 2.0.15
PyYAML 6.0.2
pyzmq 26.4.0
referencing 0.36.2
regex 2024.11.6
requests 2.32.3
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rpds-py 0.24.0
safetensors 0.5.3
Send2Trash 1.8.3
setuptools 70.2.0
six 1.17.0
sniffio 1.3.1
soupsieve 2.6
stack-data 0.6.3
sympy 1.13.1
terminado 0.18.1
tinycss2 1.4.0
tokenizers 0.21.1
torch 2.6.0+cu126
torchaudio 2.6.0+cu126
torchvision 0.21.0+cu126
tornado 6.4.2
tqdm 4.67.1
traitlets 5.14.3
transformers 4.51.3
types-python-dateutil 2.9.0.20241206
typing_extensions 4.12.2
tzdata 2025.2
uri-template 1.3.0
urllib3 2.4.0
wcwidth 0.2.13
webcolors 24.11.1
webencodings 0.5.1
websocket-client 1.8.0
widgetsnbextension 4.0.14