\n",
+ "Labels shape: torch.Size([2, 512])\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [39/39 02:15, Epoch 3/3]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step | \n",
+ " Training Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 5 | \n",
+ " 12.130200 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 3.432800 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.502100 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.297100 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.232200 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.199000 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.174700 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=39, training_loss=2.1929761950786295, metrics={'train_runtime': 140.2841, 'train_samples_per_second': 1.112, 'train_steps_per_second': 0.278, 'total_flos': 3409289020440576.0, 'train_loss': 2.1929761950786295, 'epoch': 3.0})"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from transformers import TrainingArguments, Trainer\n",
+ "\n",
+ "# 1. Configurar formato del dataset como tensores\n",
+ "tokenized_dataset.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\"])\n",
+ "\n",
+ "# 2. Data collator mejorado\n",
+ "def custom_collator(features):\n",
+ " return {\n",
+ " \"input_ids\": torch.stack([torch.tensor(f[\"input_ids\"]) for f in features]),\n",
+ " \"attention_mask\": torch.stack([torch.tensor(f[\"attention_mask\"]) for f in features]),\n",
+ " \"labels\": torch.stack([torch.tensor(f[\"input_ids\"]) for f in features])\n",
+ " }\n",
+ "\n",
+ "# 3. Configurar argumentos con parámetros faltantes\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir=\"./html5-lora\",\n",
+ " per_device_train_batch_size=2,\n",
+ " gradient_accumulation_steps=2, # Reducir para ahorrar memoria\n",
+ " num_train_epochs=3,\n",
+ " learning_rate=3e-4,\n",
+ " fp16=torch.cuda.is_available(),\n",
+ " logging_steps=5,\n",
+ " report_to=\"none\",\n",
+ " remove_unused_columns=False, # Necesario para LoRA\n",
+ " label_names=[\"labels\"] # Añadir parámetro faltante\n",
+ ")\n",
+ "\n",
+ "# 4. Crear Trainer con parámetros actualizados\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=tokenized_dataset[\"train\"],\n",
+ " eval_dataset=tokenized_dataset[\"test\"],\n",
+ " data_collator=custom_collator\n",
+ ")\n",
+ "\n",
+ "# 5. Verificación adicional\n",
+ "sample_batch = next(iter(trainer.get_train_dataloader()))\n",
+ "print(\"\\nVerificación de batch:\")\n",
+ "print(f\"Input ids type: {type(sample_batch['input_ids'][0])}\")\n",
+ "print(f\"Labels shape: {sample_batch['labels'].shape}\")\n",
+ "\n",
+ "# 6. Iniciar entrenamiento\n",
+ "trainer.train()"
+ ]
+ },
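+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal sketch of persisting the result: calling `save_pretrained` on a PEFT model writes only the LoRA adapter weights and config, not the full base model. The path `./html5-lora-adapter` is a placeholder chosen here, not a name defined earlier in the notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: persist the trained LoRA adapter (directory name is a placeholder)\n",
+ "adapter_dir = \"./html5-lora-adapter\"\n",
+ "model.save_pretrained(adapter_dir)      # writes adapter weights + adapter_config.json\n",
+ "tokenizer.save_pretrained(adapter_dir)  # keep the tokenizer next to the adapter"
+ ]
+ },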
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hm89m0JCtYnY"
+ },
+ "source": [
+ "### Generación de Respuestas"
+ ]
+ },
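+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If the runtime restarts, the adapter can be reattached to its base model before building the pipeline below. A hedged sketch, assuming the adapter was saved to `./html5-lora-adapter` as above and that `BASE_MODEL` names the same checkpoint fine-tuned earlier (both are placeholders, not variables defined in this notebook):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: reload base model + LoRA adapter (names are placeholders)\n",
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+ "from peft import PeftModel\n",
+ "\n",
+ "BASE_MODEL = \"...\"  # the base checkpoint used earlier in the notebook\n",
+ "base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)\n",
+ "model = PeftModel.from_pretrained(base, \"./html5-lora-adapter\")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"./html5-lora-adapter\")"
+ ]
+ },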
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rukjNbmftfCv",
+ "outputId": "e7c3781f-1a33-4a43-9c8c-4eb1e68589e2"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cuda:0\n",
+ "The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'GraniteMoeSharedForCausalLM', 'HeliumForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MllamaForCausalLM', 'MoshiForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'NemotronForCausalLM', 'OlmoForCausalLM', 'Olmo2ForCausalLM', 'OlmoeForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'PhimoeForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM', 'ZambaForCausalLM', 'Zamba2ForCausalLM'].\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import pipeline\n",
+ "chatbot = pipeline(\n",
+ " \"text-generation\",\n",
+ " model = model,\n",
+ " tokenizer = tokenizer,\n",
+ " torch_dtype = torch.float16\n",
+ ")\n",
+ "\n",
+ "def generate_response(query):\n",
+ " prompt = f\"[INST] Pregunta HTML5: {query} [/INST]\"\n",
+ " response = chatbot(\n",
+ " prompt,\n",
+ " max_new_tokens = 200,\n",
+ " temperature = 0.3,\n",
+ " do_sample = True,\n",
+ " pad_token_id = tokenizer.eos_token_id\n",
+ " )\n",
+ " return response[0]['generated_text'].split(\"[/INST]\")[-1].strip()\n",
+ "\n",
+ "\n",
+ "\n",
+ "def generate_response_gradio(query):\n",
+ " try:\n",
+ " # Manejar casos no técnicos primero\n",
+ " if query.lower().strip() in [\"hola\", \"hi\", \"ayuda\"]:\n",
+ " return \"¡Hola! Soy un asistente de HTML5. Ejemplo: '¿Cómo usar