Gemma3ForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class

#57
by buckeye17-bah - opened

I'm trying to run the following code, taken straight from the Gemma3 documentation on Hugging Face. It produces the warning and the error shown below. I've traced the error to the `cache_implementation="static"` argument of `model.generate()`: dropping that argument lets the model run fine.

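Based on the error message, the static cache appears to require the triton Python package, which, as far as I know, can only be installed on Linux. The traceback shows that with the static cache, `generate()` runs the forward pass through torch.compile (torch._dynamo with the inductor backend), and it is inductor's GPU code generation that looks for triton. I'm planning to test this in a Docker container soon to confirm whether the missing package is the real culprit, and I'll report back then.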
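A minimal sketch of the workaround described above, assuming that an importable `triton` package is a reasonable proxy for whether the static cache will compile. `triton_available` and `generate_kwargs` are hypothetical helpers, not part of transformers. Note that `importlib.util.find_spec` only checks that the package can be found, not that its version is new enough, which the error message also lists as a possible cause.

```python
import importlib.util
import platform
from typing import Dict, Optional, Union


def triton_available() -> bool:
    """Return True if a package named `triton` is on the import path.

    This only checks that the package can be found; it does not check that
    its version is recent enough for torch.compile's inductor backend.
    """
    return importlib.util.find_spec("triton") is not None


def generate_kwargs(
    max_new_tokens: int = 50, has_triton: Optional[bool] = None
) -> Dict[str, Union[int, str]]:
    """Build keyword arguments for model.generate() (hypothetical helper).

    Only request the static cache when triton is present; otherwise leave
    cache_implementation unset so generate() uses its default cache, which
    runs fine in this environment.
    """
    if has_triton is None:
        has_triton = triton_available()
    kwargs: Dict[str, Union[int, str]] = {"max_new_tokens": max_new_tokens}
    if has_triton:
        kwargs["cache_implementation"] = "static"
    return kwargs


if __name__ == "__main__":
    print(platform.system(), "| triton found:", triton_available())
    print(generate_kwargs())
```

With the model and inputs from the snippet below, this would be used as `model.generate(**inputs, **generate_kwargs())`.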

```python
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3-4b-it",
    padding_side="left"
)

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to("cuda")

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
```
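Until the root cause is confirmed, one option is to try the static cache first and retry with the default cache if compilation fails. The sketch below is hypothetical (`generate_with_fallback` is not a transformers API). It takes the generate function as a parameter so it can be exercised without a GPU, and `compile_errors` lets the caller narrow the retry to the compiler failure instead of every exception.

```python
from typing import Any, Callable, Tuple, Type


def generate_with_fallback(
    generate_fn: Callable[..., Any],
    *args: Any,
    compile_errors: Tuple[Type[BaseException], ...] = (Exception,),
    **kwargs: Any,
) -> Any:
    """Call generate_fn with a static cache, retrying without it on failure."""
    # This helper controls cache_implementation itself, so drop any caller value
    # to avoid passing the keyword twice.
    kwargs.pop("cache_implementation", None)
    try:
        return generate_fn(*args, cache_implementation="static", **kwargs)
    except compile_errors:
        # Retry with generate()'s default cache, which runs without triton here.
        return generate_fn(*args, **kwargs)
```

With the model above this would be called as `generate_with_fallback(model.generate, **inputs, max_new_tokens=50, compile_errors=(BackendCompilerFailed,))`, assuming `BackendCompilerFailed` is importable from `torch._dynamo.exc` in torch 2.6. The failed compile attempt is not free, so this is mostly useful for code that has to run on both Linux and Windows.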

Warning message:

Gemma3ForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're writing code, see Llama for an example implementation. If you're a user, please report this issue on GitHub.

Error message:

BackendCompilerFailed                     Traceback (most recent call last)
Cell In[1], line 37
     15 messages = [
     16     {
     17         "role": "system",
   (...)     27     },
     28 ]
     29 inputs = processor.apply_chat_template(
     30     messages,
     31     tokenize=True,
   (...)     34     add_generation_prompt=True,
     35 ).to("cuda")
---> 37 output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
     38 print(processor.decode(output[0], skip_special_tokens=True))

File [my venv root dir]\gemma3\Lib\site-packages\torch\utils\_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File [my venv root dir]\gemma3\Lib\site-packages\transformers\generation\utils.py:2465, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)
   2457     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2458         input_ids=input_ids,
   2459         expand_size=generation_config.num_return_sequences,
   2460         is_encoder_decoder=self.config.is_encoder_decoder,
   2461         **model_kwargs,
   2462     )
   2464     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2465     result = self._sample(
   2466         input_ids,
   2467         logits_processor=prepared_logits_processor,
   2468         stopping_criteria=prepared_stopping_criteria,
   2469         generation_config=generation_config,
   2470         synced_gpus=synced_gpus,
   2471         streamer=streamer,
   2472         **model_kwargs,
   2473     )
   2475 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2476     # 11. interleave input_ids with `num_beams` additional sequences per batch
   2477     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2478         input_ids=input_ids,
   2479         expand_size=generation_config.num_beams,
   2480         is_encoder_decoder=self.config.is_encoder_decoder,
   2481         **model_kwargs,
   2482     )

File [my venv root dir]\gemma3\Lib\site-packages\transformers\generation\utils.py:3434, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3432     is_prefill = False
   3433 else:
-> 3434     outputs = model_forward(**model_inputs, return_dict=True)
   3436 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
   3437 model_kwargs = self._update_model_kwargs_for_generation(
   3438     outputs,
   3439     model_kwargs,
   3440     is_encoder_decoder=self.config.is_encoder_decoder,
   3441 )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\eval_frame.py:574, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    569 saved_dynamic_layer_stack_depth = (
    570     torch._C._functorch.get_dynamic_layer_stack_depth()
    571 )
    573 try:
--> 574     return fn(*args, **kwargs)
    575 finally:
    576     # Restore the dynamic layer stack depth if necessary.
    577     torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
    578         saved_dynamic_layer_stack_depth
    579     )

File [my venv root dir]\gemma3\Lib\site-packages\torch\nn\modules\module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File [my venv root dir]\gemma3\Lib\site-packages\torch\nn\modules\module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:1380, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
   1374             return hijacked_callback(
   1375                 frame, cache_entry, self.hooks, frame_state
   1376             )
   1378 with compile_lock, _disable_current_modes():
   1379     # skip=1: skip this frame
-> 1380     return self._torchdynamo_orig_callable(
   1381         frame, cache_entry, self.hooks, frame_state, skip=1
   1382     )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:547, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
    544     dynamo_tls.traced_frame_infos.append(info)
    546 with compile_context(CompileContext(compile_id)):
--> 547     return _compile(
    548         frame.f_code,
    549         frame.f_globals,
    550         frame.f_locals,
    551         frame.f_builtins,
    552         frame.closure,
    553         self._torchdynamo_orig_callable,
    554         self._one_graph,
    555         self._export,
    556         self._export_constraints,
    557         hooks,
    558         cache_entry,
    559         cache_size,
    560         frame,
    561         frame_state=frame_state,
    562         compile_id=compile_id,
    563         skip=skip + 1,
    564     )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:986, in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    984 guarded_code = None
    985 try:
--> 986     guarded_code = compile_inner(code, one_graph, hooks, transform)
    988     # NB: We only put_code_state in success case.  Success case here
    989     # does include graph breaks; specifically, if a graph break still
    990     # resulted in a partially compiled graph, we WILL return here.  An
   (...)    995     # to upload for graph break though, because this can prevent
    996     # extra graph break compilations.)
    997     put_code_state()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:715, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    713     stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
    714     stack.enter_context(CompileTimeInstructionCounter.record())
--> 715     return _compile_inner(code, one_graph, hooks, transform)
    717 return None

File [my venv root dir]\gemma3\Lib\site-packages\torch\_utils_internal.py:95, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
     92     kwargs["skip"] = skip + 1
     94 if not StrobelightCompileTimeProfiler.enabled:
---> 95     return function(*args, **kwargs)
     97 return StrobelightCompileTimeProfiler.profile_compile_time(
     98     function, phase_name, *args, **kwargs
     99 )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:750, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
    748 CompileContext.get().attempt = attempt
    749 try:
--> 750     out_code = transform_code_object(code, transform)
    751     break
    752 except exc.RestartAnalysis as e:

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\bytecode_transformation.py:1361, in transform_code_object(code, transformations, safe)
   1358 instructions = cleaned_instructions(code, safe)
   1359 propagate_line_nums(instructions)
-> 1361 transformations(instructions, code_options)
   1362 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:231, in preserve_global_state.<locals>._fn(*args, **kwargs)
    229 exit_stack.enter_context(torch_function_mode_stack_state_mgr)
    230 try:
--> 231     return fn(*args, **kwargs)
    232 finally:
    233     cleanup.close()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\convert_frame.py:662, in _compile.<locals>.transform(instructions, code_options)
    660 try:
    661     with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 662         tracer.run()
    663 except exc.UnspecializeRestartAnalysis:
    664     speculation_log.clear()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:2868, in InstructionTranslator.run(self)
   2867 def run(self):
-> 2868     super().run()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
   1050 try:
   1051     self.output.push_tx(self)
-> 1052     while self.step():
   1053         pass
   1054 except TensorifyScalarRestartAnalysis:

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
    959 self.update_block_stack(inst)
    961 try:
--> 962     self.dispatch_table[inst.opcode](self, inst)
    963     return not self.output.should_exit
    964 except TensorifyScalarRestartAnalysis:

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:3048, in InstructionTranslator.RETURN_VALUE(self, inst)
   3047 def RETURN_VALUE(self, inst):
-> 3048     self._return(inst)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\symbolic_convert.py:3033, in InstructionTranslator._return(self, inst)
   3028 _step_logger()(
   3029     logging.INFO,
   3030     f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
   3031 )
   3032 log.debug("%s triggered compile", inst.opname)
-> 3033 self.output.compile_subgraph(
   3034     self,
   3035     reason=GraphCompileReason(
   3036         "return_value", [self.frame_summary()], graph_break=False
   3037     ),
   3038 )
   3039 return_inst = (
   3040     create_instruction("RETURN_VALUE")
   3041     if inst.opname == "RETURN_VALUE"
   3042     else create_instruction("RETURN_CONST", argval=inst.argval)
   3043 )
   3044 self.output.add_output_instructions([return_inst])

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1136, in OutputGraph.compile_subgraph(self, tx, partial_convert, reason)
   1133 output = []
   1134 if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
   1135     output.extend(
-> 1136         self.compile_and_call_fx_graph(
   1137             tx, pass2.graph_output_vars(), root, output_replacements
   1138         )
   1139     )
   1141     if len(pass2.graph_outputs) != 0:
   1142         output.append(pass2.create_store(graph_output_var))

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1382, in OutputGraph.compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs)
   1379     self.tracing_context.fake_mode = backend_fake_mode
   1381 with self.restore_global_state():
-> 1382     compiled_fn = self.call_user_compiler(gm)
   1384 from torch.fx._lazy_graph_module import _LazyGraphModule
   1386 if isinstance(compiled_fn, _LazyGraphModule) or (
   1387     isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
   1388     and compiled_fn.__name__ == "_lazy_forward"  # type: ignore[attr-defined]
   (...)   1392     # this is a _LazyGraphModule. This makes it easier for dynamo to
   1393     # optimize a _LazyGraphModule.

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1432, in OutputGraph.call_user_compiler(self, gm)
   1425 def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
   1426     with dynamo_timed(
   1427         "OutputGraph.call_user_compiler",
   1428         phase_name="backend_compile",
   1429         log_pt2_compile_event=True,
   1430         dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
   1431     ):
-> 1432         return self._call_user_compiler(gm)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1483, in OutputGraph._call_user_compiler(self, gm)
   1481     raise e
   1482 except Exception as e:
-> 1483     raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
   1484         e.__traceback__
   1485     ) from None
   1487 signpost_event(
   1488     "dynamo",
   1489     "OutputGraph.call_user_compiler",
   (...)   1495     },
   1496 )
   1498 return compiled_fn

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\output_graph.py:1462, in OutputGraph._call_user_compiler(self, gm)
   1460 if config.verify_correctness:
   1461     compiler_fn = WrapperBackend(compiler_fn)
-> 1462 compiled_fn = compiler_fn(gm, self.example_inputs())
   1463 _step_logger()(logging.INFO, f"done compiler function {name}")
   1464 assert callable(compiled_fn), "compiler_fn did not return callable"

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py:130, in WrapBackendDebug.__call__(self, gm, example_inputs, **kwargs)
    128             raise
    129 else:
--> 130     compiled_gm = compiler_fn(gm, example_inputs)
    132 return compiled_gm

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py:130, in WrapBackendDebug.__call__(self, gm, example_inputs, **kwargs)
    128             raise
    129 else:
--> 130     compiled_gm = compiler_fn(gm, example_inputs)
    132 return compiled_gm

File [my venv root dir]\gemma3\Lib\site-packages\torch\__init__.py:2340, in _TorchCompileInductorWrapper.__call__(self, model_, inputs_)
   2337 def __call__(self, model_, inputs_):
   2338     from torch._inductor.compile_fx import compile_fx
-> 2340     return compile_fx(model_, inputs_, config_patches=self.config)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1552, in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions)
   1550 if config_patches:
   1551     with config.patch(config_patches):
-> 1552         return compile_fx(
   1553             model_,
   1554             example_inputs_,
   1555             # need extra layer of patching as backwards is compiled out of scope
   1556             inner_compile=config.patch(config_patches)(inner_compile),
   1557             decompositions=decompositions,
   1558         )
   1560 # TODO: This probably shouldn't be a recursive call
   1561 if config.cpp_wrapper:

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1863, in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions)
   1856         return inference_compiler(unlifted_gm, example_inputs_)
   1858 with V.set_fake_mode(fake_mode), torch._guards.tracing(
   1859     tracing_context
   1860 ), compiled_autograd._disable(), functorch_config.patch(
   1861     unlift_effect_tokens=True
   1862 ):
-> 1863     return aot_autograd(
   1864         fw_compiler=fw_compiler,
   1865         bw_compiler=bw_compiler,
   1866         inference_compiler=inference_compiler,
   1867         decompositions=decompositions,
   1868         partition_fn=partition_fn,
   1869         keep_inference_input_mutations=True,
   1870         cudagraphs=cudagraphs,
   1871     )(model_, example_inputs_)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\backends\common.py:83, in AotAutograd.__call__(self, gm, example_inputs, **kwargs)
     80 try:
     81     # NB: NOT cloned!
     82     with enable_aot_logging(), patch_config:
---> 83         cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
     84         counters["aot_autograd"]["ok"] += 1
     85         return disable(cg)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:1155, in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler, cudagraphs)
   1145     compiled_fn = AOTAutogradCache.load(
   1146         dispatch_and_compile,
   1147         mod,
   (...)   1152         remote,
   1153     )
   1154 else:
-> 1155     compiled_fn = dispatch_and_compile()
   1157 if isinstance(mod, torch._dynamo.utils.GmWrapper):
   1158     # This function is called by the flatten_graph_inputs wrapper, which boxes
   1159     # the inputs so that they can be freed before the end of this scope.
   1160     # For overhead reasons, this is not the default wrapper, see comment:
   1161     # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
   1162     def boxed_forward(runtime_args: List[Any]):

File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:1131, in aot_module_simplified.<locals>.dispatch_and_compile()
   1129 functional_call = create_functional_call(mod, params_spec, params_len)
   1130 with compiled_autograd._disable():
-> 1131     compiled_fn, _ = create_aot_dispatcher_function(
   1132         functional_call,
   1133         fake_flat_args,
   1134         aot_config,
   1135         fake_mode,
   1136         shape_env,
   1137     )
   1138 return compiled_fn

File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:580, in create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)
    572 def create_aot_dispatcher_function(
    573     flat_fn,
    574     fake_flat_args: FakifiedFlatArgs,
   (...)    577     shape_env: Optional[ShapeEnv],
    578 ) -> Tuple[Callable, ViewAndMutationMeta]:
    579     with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True):
--> 580         return _create_aot_dispatcher_function(
    581             flat_fn, fake_flat_args, aot_config, fake_mode, shape_env
    582         )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:830, in _create_aot_dispatcher_function(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)
    826         return aot_dispatch_base
    828 compiler_fn = choose_dispatcher(needs_autograd, aot_config)
--> 830 compiled_fn, fw_metadata = compiler_fn(
    831     flat_fn,
    832     _dup_fake_script_obj(fake_flat_args),
    833     aot_config,
    834     fw_metadata=fw_metadata,
    835 )
    836 return compiled_fn, fw_metadata

File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\_aot_autograd\jit_compile_runtime_wrappers.py:203, in aot_dispatch_base(flat_fn, flat_args, aot_config, fw_metadata)
    201         assert isinstance(fw_module, GraphModule)
    202         tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode)
--> 203     compiled_fw = compiler(fw_module, updated_flat_args)
    205 if fakified_out_wrapper.needs_post_compile:
    206     fakified_out_wrapper.set_fwd_output_strides(fwd_output_strides)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_functorch\aot_autograd.py:489, in SerializableAOTDispatchCompiler.__call__(self, gm, example_inputs)
    484 def __call__(
    485     self,
    486     gm: torch.fx.GraphModule,
    487     example_inputs: Sequence[InputType],
    488 ) -> OutputCode:
--> 489     return self.compiler_fn(gm, example_inputs)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1741, in compile_fx.<locals>.fw_compiler_base(gm, example_inputs, is_inference)
   1738 else:
   1739     model_outputs_node.meta["user_visible_output_idxs"] = []
-> 1741 return inner_compile(
   1742     gm,
   1743     example_inputs,
   1744     static_input_idxs=get_static_input_idxs(fixed),
   1745     cudagraphs=cudagraphs,
   1746     graph_id=graph_id,
   1747     is_inference=is_inference,
   1748     boxed_forward_device_index=forward_device,
   1749 )

File C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.12_3.12.2800.0_x64__qbz5n2kfra8p0\Lib\contextlib.py:81, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     78 @wraps(func)
     79 def inner(*args, **kwds):
     80     with self._recreate_cm():
---> 81         return func(*args, **kwds)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:569, in compile_fx_inner(gm, example_inputs, **kwargs)
    562 stack.enter_context(DebugContext())
    564 get_chromium_event_logger().add_event_data(
    565     "inductor_compile",
    566     is_backward=kwargs["is_backward"],
    567 )
--> 569 return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
    570     gm,
    571     example_inputs,
    572     **kwargs,
    573 )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_dynamo\repro\after_aot.py:102, in wrap_compiler_debug.<locals>.debug_wrapper(gm, example_inputs, **kwargs)
     97 assert config.repro_after in ("dynamo", "aot", None)
     99 try:
    100     # Call the compiler_fn - which is either aot_autograd or inductor
    101     # with fake inputs
--> 102     inner_compiled_fn = compiler_fn(gm, example_inputs)
    103 except Exception as e:
    104     # TODO: Failures here are troublesome because no real inputs,
    105     # need a different serialization strategy
    106     if config.repro_after == "aot":

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:685, in _compile_fx_inner(gm, example_inputs, **graph_kwargs)
    683 TritonBundler.begin_compile()
    684 try:
--> 685     mb_compiled_graph = fx_codegen_and_compile(
    686         gm, example_inputs, inputs_to_check, **graph_kwargs
    687     )
    688     assert mb_compiled_graph is not None
    689     mb_compiled_graph._time_taken_ns = time.time_ns() - start_time

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1129, in fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)
   1119 def fx_codegen_and_compile(
   1120     gm: GraphModule,
   1121     example_inputs: Sequence[InputType],
   (...)   1125     **graph_kwargs: Unpack[_CompileFxKwargs],
   1126 ) -> OutputCode:
   1127     scheme: FxCompile = _InProcessFxCompile()
-> 1129     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\compile_fx.py:1044, in _InProcessFxCompile.codegen_and_compile(self, gm, example_inputs, inputs_to_check, graph_kwargs)
   1036             compiled_fn = AotCodeCompiler.compile(
   1037                 graph,
   1038                 code,
   (...)   1041                 additional_files=additional_files,
   1042             )
   1043     else:
-> 1044         compiled_fn = graph.compile_to_module().call
   1046 num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
   1047 metrics.num_bytes_accessed += num_bytes

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\graph.py:2027, in GraphLowering.compile_to_module(self)
   2020 def compile_to_module(self) -> ModuleType:
   2021     with dynamo_timed(
   2022         "GraphLowering.compile_to_module",
   2023         phase_name="code_gen",
   2024         log_pt2_compile_event=True,
   2025         dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
   2026     ):
-> 2027         return self._compile_to_module()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\graph.py:2033, in GraphLowering._compile_to_module(self)
   2029 def _compile_to_module(self) -> ModuleType:
   2030     from .codecache import PyCodeCache
   2032     code, linemap = (
-> 2033         self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
   2034     )
   2035     if config.triton.autotune_at_compile_time:
   2036         tuning_code = (
   2037             '"""\n'
   2038             + "Compile-time auto-tuning block: \n"
   (...)   2041             + '"""\n'
   2042         )

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\graph.py:1964, in GraphLowering.codegen(self)
   1960 from .scheduler import Scheduler
   1962 self.init_wrapper_code()
-> 1964 self.scheduler = Scheduler(self.operations)
   1965 V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
   1967 self.wrapper_code.push_codegened_graph(self)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:1798, in Scheduler.__init__(self, nodes)
   1796 def __init__(self, nodes: List[ir.Operation]) -> None:
   1797     with dynamo_timed("Scheduler.__init__"):
-> 1798         self._init(nodes)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:1816, in Scheduler._init(self, nodes)
   1807 self.completed_operations: OrderedSet[str] = OrderedSet()
   1808 self.available_buffer_names = OrderedSet(
   1809     [
   1810         *V.graph.graph_inputs.keys(),
   (...)   1813     ]
   1814 )
-> 1816 self.nodes = [self.create_scheduler_node(n) for n in nodes]
   1817 self.update_zero_dim_cpu_tensor()
   1818 # some new constants could have been created above

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:1947, in Scheduler.create_scheduler_node(self, node)
   1945     return NopKernelSchedulerNode(self, node)
   1946 elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
-> 1947     return SchedulerNode(self, node)
   1948 elif isinstance(node, ir.ExternKernel):
   1949     return ExternKernelSchedulerNode(self, node)

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:893, in SchedulerNode.__init__(self, scheduler, node)
    891 super().__init__(scheduler)
    892 self._init_from_node(node)
--> 893 self._compute_attrs()

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:907, in SchedulerNode._compute_attrs(self, extra_indexing_constraints, recompute_sizes_body_func)
    901 self._sizes, self._body = self.node.simplify_and_reorder(
    902     extra_indexing_constraints=extra_indexing_constraints,
    903     recompute_sizes_body_func=recompute_sizes_body_func,
    904 )
    906 device = self.node.get_device_or_error()
--> 907 group_fn = self.scheduler.get_backend(device).group_fn
    908 self.group = (device, group_fn(self._sizes))
    910 # Don't normalize since normalization will merge loops which
    911 # makes it hard to decide new loop orders.

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:3441, in Scheduler.get_backend(self, device)
   3439 assert device is not None
   3440 if device not in self.backends:
-> 3441     self.backends[device] = self.create_backend(device)
   3442 return self.backends[device]

File [my venv root dir]\gemma3\Lib\site-packages\torch\_inductor\scheduler.py:3432, in Scheduler.create_backend(self, device)
   3428         raise RuntimeError(
   3429             f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}"  # noqa: B950
   3430         )
   3431     elif is_gpu(device.type):
-> 3432         raise RuntimeError(
   3433             "Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton"  # noqa: B950
   3434         )
   3436 return device_scheduling(self)

BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Cannot find a working triton installation. Either the package is not installed or it is too old. More information on installing Triton can be found at https://github.com/openai/triton

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
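The inductor error comes from an `if`/`elif` pair in `Scheduler.create_backend`: the first branch rejects GPUs below CUDA capability 7.0, the second reports a missing or outdated triton. Since the `elif` branch fired, the GPU appears to clear the capability check, which points back at the triton installation. A minimal sketch of the same capability test, assuming it is fed the `(major, minor)` tuple returned by `torch.cuda.get_device_capability()`; `supports_triton` is a hypothetical helper.

```python
from typing import Tuple

# Minimum CUDA compute capability quoted in the inductor error message above.
MIN_TRITON_CAPABILITY: Tuple[int, int] = (7, 0)


def supports_triton(capability: Tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability meets the minimum.

    Tuples compare element by element, so (7, 0) and (8, 6) pass
    while (6, 1) does not.
    """
    return tuple(capability) >= MIN_TRITON_CAPABILITY
```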

Python environment:

Package                   Version
------------------------- --------------
accelerate                1.6.0
anyio                     4.9.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 3.0.0
async-lru                 2.0.5
attrs                     25.3.0
babel                     2.17.0
beautifulsoup4            4.13.4
bleach                    6.2.0
certifi                   2025.1.31
cffi                      1.17.1
charset-normalizer        3.4.1
colorama                  0.4.6
comm                      0.2.2
debugpy                   1.8.14
decorator                 5.2.1
defusedxml                0.7.1
executing                 2.2.0
fastjsonschema            2.21.1
filelock                  3.13.1
fqdn                      1.5.1
fsspec                    2024.6.1
h11                       0.14.0
httpcore                  1.0.8
httpx                     0.28.1
huggingface-hub           0.30.2
idna                      3.10
ipykernel                 6.29.5
ipython                   9.1.0
ipython_pygments_lexers   1.1.1
ipywidgets                8.1.6
isoduration               20.11.0
jedi                      0.19.2
Jinja2                    3.1.4
json5                     0.12.0
jsonpointer               3.0.0
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
jupyter-events            0.12.0
jupyter-lsp               2.2.5
jupyter_server            2.15.0
jupyter_server_terminals  0.5.3
jupyterlab                4.4.0
jupyterlab_pygments       0.3.0
jupyterlab_server         2.27.3
jupyterlab_widgets        3.0.14
kaleido                   0.1.0.post1
MarkupSafe                2.1.5
matplotlib-inline         0.1.7
mistune                   3.1.3
mpmath                    1.3.0
narwhals                  1.35.0
nbclient                  0.10.2
nbconvert                 7.16.6
nbformat                  5.10.4
nest-asyncio              1.6.0
networkx                  3.3
notebook                  7.4.0
notebook_shim             0.2.4
numpy                     2.1.2
overrides                 7.7.0
packaging                 24.2
pandas                    2.2.3
pandocfilters             1.5.1
parso                     0.8.4
pillow                    11.0.0
pip                       25.0.1
platformdirs              4.3.7
plotly                    6.0.1
prometheus_client         0.21.1
prompt_toolkit            3.0.51
psutil                    7.0.0
pure_eval                 0.2.3
pycparser                 2.22
Pygments                  2.19.1
python-dateutil           2.9.0.post0
python-json-logger        3.3.0
pytz                      2025.2
pywin32                   310
pywinpty                  2.0.15
PyYAML                    6.0.2
pyzmq                     26.4.0
referencing               0.36.2
regex                     2024.11.6
requests                  2.32.3
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rpds-py                   0.24.0
safetensors               0.5.3
Send2Trash                1.8.3
setuptools                70.2.0
six                       1.17.0
sniffio                   1.3.1
soupsieve                 2.6
stack-data                0.6.3
sympy                     1.13.1
terminado                 0.18.1
tinycss2                  1.4.0
tokenizers                0.21.1
torch                     2.6.0+cu126
torchaudio                2.6.0+cu126
torchvision               0.21.0+cu126
tornado                   6.4.2
tqdm                      4.67.1
traitlets                 5.14.3
transformers              4.51.3
types-python-dateutil     2.9.0.20241206
typing_extensions         4.12.2
tzdata                    2025.2
uri-template              1.3.0
urllib3                   2.4.0
wcwidth                   0.2.13
webcolors                 24.11.1
webencodings              0.5.1
websocket-client          1.8.0
widgetsnbextension        4.0.14
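The list above has no `triton` entry, which is consistent with the error. A sketch for reporting only the packages that matter for this issue, using just the standard library. `package_versions` is a hypothetical helper, and the package selection in `RELEVANT` is an assumption about what is relevant here.

```python
import platform
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distributions assumed relevant to this issue; "triton" is expected to be absent.
RELEVANT = ("torch", "transformers", "accelerate", "triton")


def package_versions(names: Iterable[str] = RELEVANT) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(platform.platform(), "| Python", sys.version.split()[0])
    for name, version in package_versions().items():
        print(f"{name:<14} {version or 'not installed'}")
```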