
Fine-Tuning

Ai Ml advanced v1.0.0


Overview

Fine-tuning adapts pre-trained models to specific domains or tasks. This skill covers when to fine-tune vs prompt, parameter-efficient methods like LoRA, and training best practices.


Key Concepts

Fine-Tuning Decision Tree

┌─────────────────────────────────────────────────────────────┐
│                 When to Fine-Tune vs Prompt                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Start Here: Can prompting solve your problem?              │
│       │                                                     │
│       ├─── YES → Use prompting (cheaper, faster)            │
│       │                                                     │
│       └─── NO → Consider fine-tuning if:                    │
│            │                                                │
│            ├── Consistent output format needed              │
│            ├── Domain-specific terminology/style            │
│            ├── Need to reduce token usage                   │
│            ├── Tight latency requirements                   │
│            └── Proprietary knowledge integration            │
│                                                             │
│  Fine-Tuning Methods:                                       │
│  ┌──────────────────────────────────────────────────────┐   │
│  │                                                      │   │
│  │  Full Fine-Tuning                                    │   │
│  │  ├── Updates all parameters                          │   │
│  │  ├── Highest quality, most expensive                 │   │
│  │  └── Risk of catastrophic forgetting                 │   │
│  │                                                      │   │
│  │  LoRA (Low-Rank Adaptation)                          │   │
│  │  ├── Trains small adapter matrices                   │   │
│  │  ├── Typically <1% of parameters trained             │   │
│  │  └── Can swap adapters at inference                  │   │
│  │                                                      │   │
│  │  QLoRA                                               │   │
│  │  ├── LoRA on a 4-bit quantized base model            │   │
│  │  └── Enables fine-tuning on consumer hardware        │   │
│  │                                                      │   │
│  │  Prefix Tuning                                       │   │
│  │  ├── Trainable prefix vectors at every layer         │   │
│  │  └── Very parameter efficient                        │   │
│  │                                                      │   │
│  └──────────────────────────────────────────────────────┘   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
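The "<1% of parameters" figure for LoRA follows from the shape of the adapter. In the standard LoRA formulation, the frozen weight W_0 receives a low-rank update BA scaled by α/r. The worked count below assumes a single 4096 × 4096 attention projection (an illustrative size, not tied to a specific model) and uses r = 16 and α = 32, the values in the configuration of Example 3.

```latex
\begin{aligned}
W' &= W_0 + \frac{\alpha}{r} B A, \quad B \in \mathbb{R}^{d \times r},\ A \in \mathbb{R}^{r \times k} \\
\text{full matrix: } d\,k &= 4096 \cdot 4096 = 16{,}777{,}216 \\
\text{LoRA adapter: } r\,(d + k) &= 16 \cdot (4096 + 4096) = 131{,}072 \\
\text{trainable fraction: } \frac{131{,}072}{16{,}777{,}216} &\approx 0.78\%, \qquad \frac{\alpha}{r} = \frac{32}{16} = 2
\end{aligned}
```

Because B is initialized to zero, the adapted layer starts out identical to the base layer. The fraction for the whole model depends on which modules are targeted.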

Best Practices

1. Start with High-Quality Data

Data quality matters more than quantity: a smaller set of clean, representative examples usually beats a large, noisy one.

2. Use Validation Set

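Hold out a validation set and monitor validation loss. If it rises while training loss keeps falling, the model is overfitting.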

3. Try LoRA First

LoRA is much cheaper than full fine-tuning and is often good enough; move to full fine-tuning only if it falls short.

4. Evaluate Thoroughly

Test on held-out data and edge cases, and check general capabilities for regressions.

5. Version Models

Track datasets, hyperparameters, and metrics.
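To make the validation-set advice concrete, here is a minimal early-stopping sketch: it tracks the best validation loss and signals a stop after `patience` evaluations without meaningful improvement. The class name, the `patience` and `minDelta` parameters, and the loss values are illustrative assumptions, not taken from any particular training framework.

```java
import java.util.List;

/** Minimal early-stopping tracker driven by validation loss (illustrative sketch). */
public class EarlyStopping {
    private final int patience;      // evaluations to wait without improvement
    private final double minDelta;   // smallest decrease that counts as an improvement
    private double bestLoss = Double.POSITIVE_INFINITY;
    private int evalsWithoutImprovement = 0;

    public EarlyStopping(int patience, double minDelta) {
        this.patience = patience;
        this.minDelta = minDelta;
    }

    /** Records one validation loss; returns true when training should stop. */
    public boolean update(double validationLoss) {
        if (validationLoss < bestLoss - minDelta) {
            bestLoss = validationLoss;
            evalsWithoutImprovement = 0;
        } else {
            evalsWithoutImprovement++;
        }
        return evalsWithoutImprovement >= patience;
    }

    public double bestLoss() {
        return bestLoss;
    }

    public static void main(String[] args) {
        // Hypothetical validation losses: improving at first, then rising as the model overfits
        List<Double> losses = List.of(1.20, 0.95, 0.90, 0.92, 0.97, 1.05);
        EarlyStopping stopper = new EarlyStopping(2, 0.01);
        for (int i = 0; i < losses.size(); i++) {
            if (stopper.update(losses.get(i))) {
                System.out.println("Stop after eval " + i + ", best loss " + stopper.bestLoss());
                break;
            }
        }
    }
}
```

In a real run you would also keep the checkpoint that produced the best validation loss, not the last one.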


Code Examples

Example 1: OpenAI Fine-Tuning

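@Service
@RequiredArgsConstructor
public class OpenAiFineTuningService {
    
    private final WebClient webClient;        // assumed: base URL https://api.openai.com, Bearer API key header
    private final ObjectMapper objectMapper;
    private final LlmClient llmClient;        // chat-completion client used by complete()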
    
    /**
     * Prepare training data in JSONL format
     */
    public File prepareTrainingData(List<TrainingExample> examples, String outputPath) 
            throws IOException {
        
        File file = new File(outputPath);
        // JSONL: one JSON object per line, always written as UTF-8
        try (BufferedWriter writer = Files.newBufferedWriter(file.toPath(), StandardCharsets.UTF_8)) {
            for (TrainingExample example : examples) {
                // Map.of rejects null values, so add the system message only when present
                List<Map<String, String>> messages = new ArrayList<>();
                if (example.getSystemPrompt() != null) {
                    messages.add(Map.of("role", "system", "content", example.getSystemPrompt()));
                }
                messages.add(Map.of("role", "user", "content", example.getUserMessage()));
                messages.add(Map.of("role", "assistant", "content", example.getAssistantResponse()));
                
                writer.write(objectMapper.writeValueAsString(Map.of("messages", messages)));
                writer.write('\n');  // "\n" line endings regardless of platform
            }
        }
        
        return file;
    }
    
    /**
     * Upload training file
     */
    public String uploadTrainingFile(File file) {
        MultipartBodyBuilder builder = new MultipartBodyBuilder();
        builder.part("file", new FileSystemResource(file));
        builder.part("purpose", "fine-tune");
        
        FileUploadResponse response = webClient.post()
            .uri("/v1/files")
            .contentType(MediaType.MULTIPART_FORM_DATA)
            .body(BodyInserters.fromMultipartData(builder.build()))
            .retrieve()
            .bodyToMono(FileUploadResponse.class)
            .block();
        
        return response.getId();
    }
    
    /**
     * Create fine-tuning job
     */
    public FineTuningJob createFineTuningJob(FineTuningConfig config) {
        Map<String, Object> request = new HashMap<>();
        request.put("training_file", config.getTrainingFileId());
        request.put("model", config.getBaseModel());  // e.g., "gpt-4o-mini-2024-07-18"
        
        // Optional hyperparameters
        if (config.getEpochs() != null || config.getLearningRateMultiplier() != null) {
            Map<String, Object> hyperparams = new HashMap<>();
            if (config.getEpochs() != null) {
                hyperparams.put("n_epochs", config.getEpochs());
            }
            if (config.getLearningRateMultiplier() != null) {
                hyperparams.put("learning_rate_multiplier", config.getLearningRateMultiplier());
            }
            request.put("hyperparameters", hyperparams);
        }
        
        if (config.getValidationFileId() != null) {
            request.put("validation_file", config.getValidationFileId());
        }
        
        if (config.getSuffix() != null) {
            request.put("suffix", config.getSuffix());  // Custom model name suffix
        }
        
        return webClient.post()
            .uri("/v1/fine_tuning/jobs")
            .bodyValue(request)
            .retrieve()
            .bodyToMono(FineTuningJob.class)
            .block();
    }
    
    /**
     * Monitor fine-tuning job
     */
    public FineTuningJob getJobStatus(String jobId) {
        return webClient.get()
            .uri("/v1/fine_tuning/jobs/{id}", jobId)
            .retrieve()
            .bodyToMono(FineTuningJob.class)
            .block();
    }
    
    /**
     * List job events (training progress)
     */
    public List<FineTuningEvent> getJobEvents(String jobId) {
        FineTuningEventsResponse response = webClient.get()
            .uri("/v1/fine_tuning/jobs/{id}/events", jobId)
            .retrieve()
            .bodyToMono(FineTuningEventsResponse.class)
            .block();
        
        return response.getData();
    }
    
    /**
     * Use fine-tuned model
     */
    public CompletionResponse complete(String fineTunedModel, String prompt) {
        return llmClient.complete(
            CompletionRequest.builder()
                .model(fineTunedModel)  // e.g., "ft:gpt-4o-mini-2024-07-18:my-org::abc123"
                .messages(List.of(Message.user(prompt)))
                .build()
        );
    }
}

@Data
@Builder
class FineTuningConfig {
    private String trainingFileId;
    private String validationFileId;
    private String baseModel;
    private String suffix;
    private Integer epochs;
    private Double learningRateMultiplier;
}

@Data
class FineTuningJob {
    private String id;
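    private String status;  // validating_files, queued, running, succeeded, failed, cancelled
    @JsonProperty("fine_tuned_model")
    private String fineTunedModel;      // null until the job succeeds
    @JsonProperty("created_at")
    private Long createdAt;             // Unix timestamp in seconds
    @JsonProperty("finished_at")
    private Long finishedAt;            // null until the job finishes
    private Map<String, Object> error;  // code, message, param when the job fails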
}
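Fine-tuning jobs run asynchronously, so `getJobStatus` only returns a snapshot. One way to wait for a result is to poll with exponential backoff until the job reaches a terminal status. This is a sketch: `JobPoller` and `waitForCompletion` are hypothetical names, the status is supplied through a `Supplier<String>` so the loop is independent of the HTTP client, and the terminal set assumes OpenAI's terminal statuses `succeeded`, `failed`, and `cancelled`.

```java
import java.time.Duration;
import java.util.Set;
import java.util.function.Supplier;

/** Polls a job-status source until the job reaches a terminal state (illustrative sketch). */
public class JobPoller {

    // Terminal values of the "status" field on an OpenAI fine-tuning job
    static final Set<String> TERMINAL = Set.of("succeeded", "failed", "cancelled");

    /**
     * @param fetchStatus  supplies the current status, e.g. () -> service.getJobStatus(jobId).getStatus()
     * @param initialDelay wait after the first non-terminal poll; doubles each time up to maxDelay
     * @param maxDelay     upper bound on the wait between polls
     * @param maxAttempts  number of polls before giving up
     * @return the terminal status that was observed
     */
    public static String waitForCompletion(Supplier<String> fetchStatus,
                                           Duration initialDelay,
                                           Duration maxDelay,
                                           int maxAttempts) throws InterruptedException {
        Duration delay = initialDelay;
        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            String status = fetchStatus.get();
            if (TERMINAL.contains(status)) {
                return status;
            }
            if (attempt < maxAttempts) {
                Thread.sleep(delay.toMillis());
                // Exponential backoff, capped so long-running jobs are still checked regularly
                delay = delay.multipliedBy(2);
                if (delay.compareTo(maxDelay) > 0) {
                    delay = maxDelay;
                }
            }
        }
        throw new IllegalStateException("Job still running after " + maxAttempts + " polls");
    }
}
```

A failed or cancelled job still returns normally, so check the returned status (and the job's error details) before using the fine-tuned model name.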

Example 2: Training Data Preparation

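@Service
@RequiredArgsConstructor
public class TrainingDataService {
    
    private final ParaphraseService paraphraseService;  // project-specific paraphrasing component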
    
    /**
     * Create training examples from historical data
     */
    public List<TrainingExample> prepareExamples(List<ConversationLog> logs) {
        return logs.stream()
            .filter(this::isHighQuality)
            .map(this::toTrainingExample)
            .toList();
    }
    
    private boolean isHighQuality(ConversationLog log) {
        // Filter criteria
        return log.getUserRating() >= 4.0 &&              // User rated positively
               log.getResponseTime() < 5000 &&            // Fast response
               log.getAssistantMessage().length() > 50 && // Substantive response
               !log.isEdited();                            // Not manually corrected
    }
    
    private TrainingExample toTrainingExample(ConversationLog log) {
        return TrainingExample.builder()
            .systemPrompt(buildSystemPrompt(log.getContext()))
            .userMessage(log.getUserMessage())
            .assistantResponse(log.getAssistantMessage())
            .build();
    }
    
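    /**
     * Split data for training/validation (reproducible for a given seed)
     */
    public DataSplit splitData(List<TrainingExample> examples, double validationRatio, long seed) {
        // Copy first: prepareExamples() returns an unmodifiable list (Stream.toList()),
        // which Collections.shuffle cannot modify
        List<TrainingExample> shuffled = new ArrayList<>(examples);
        Collections.shuffle(shuffled, new Random(seed));
        
        int validationSize = (int) (shuffled.size() * validationRatio);
        
        return new DataSplit(
            List.copyOf(shuffled.subList(validationSize, shuffled.size())),  // Training
            List.copyOf(shuffled.subList(0, validationSize))                 // Validation
        );
    }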
    
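    /**
     * Augment training data. Apply this to the training split only: augmenting
     * before splitting puts paraphrases of validation examples into the training set.
     */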
    public List<TrainingExample> augmentData(List<TrainingExample> examples) {
        List<TrainingExample> augmented = new ArrayList<>(examples);
        
        for (TrainingExample example : examples) {
            // Paraphrase user messages
            String paraphrasedUser = paraphraseService.paraphrase(example.getUserMessage());
            augmented.add(example.withUserMessage(paraphrasedUser));
            
            // Add typos/variations for robustness
            String withTypos = addRealisticVariations(example.getUserMessage());
            augmented.add(example.withUserMessage(withTypos));
        }
        
        return augmented;
    }
    
    /**
     * Validate training data format
     */
    public ValidationResult validateTrainingData(List<TrainingExample> examples) {
        List<String> errors = new ArrayList<>();
        List<String> warnings = new ArrayList<>();
        
        // Check minimum size (OpenAI requires at least 10 training examples)
        if (examples.size() < 10) {
            errors.add("Need at least 10 examples, got " + examples.size());
        }
        
        // Check for exact duplicates (same user message and response)
        Set<String> seen = new HashSet<>();
        for (int i = 0; i < examples.size(); i++) {
            TrainingExample ex = examples.get(i);
            String key = ex.getUserMessage() + "|||" + ex.getAssistantResponse();
            if (!seen.add(key)) {
                warnings.add("Example " + i + " duplicates an earlier example");
            }
        }
        
        // Check token counts; the per-example limit depends on the base model,
        // so adjust 16,000 to the documented limit for the model being tuned
        for (int i = 0; i < examples.size(); i++) {
            TrainingExample ex = examples.get(i);
            int tokens = countTokens(ex);
            if (tokens > 16000) {
                errors.add("Example " + i + " exceeds token limit: " + tokens);
            }
        }
        
        // Check for data leakage
        checkDataLeakage(examples, warnings);
        
        return new ValidationResult(errors.isEmpty(), errors, warnings);
    }
    
    record DataSplit(List<TrainingExample> training, List<TrainingExample> validation) {}
}
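`splitData` shuffles before slicing, but exact or near-duplicate examples can still land on both sides of the split, which inflates validation scores. Below is a minimal sketch of deduplicating before a seeded, reproducible split. It uses plain strings instead of `TrainingExample`, and the normalization rule (lowercase, collapsed whitespace) is an assumption you would tune to your data; `LeakFreeSplit` is a hypothetical helper.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;

/** Deduplicate, then split reproducibly (illustrative sketch using strings as examples). */
public class LeakFreeSplit {

    public record Split(List<String> training, List<String> validation) {}

    // Normalization used to detect near-identical examples: lowercase, collapse whitespace
    static String normalize(String text) {
        return text.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ");
    }

    public static Split split(List<String> examples, double validationRatio, long seed) {
        // Keep the first occurrence of each normalized example so duplicates
        // cannot end up on both sides of the split
        Map<String, String> unique = new LinkedHashMap<>();
        for (String ex : examples) {
            unique.putIfAbsent(normalize(ex), ex);
        }
        List<String> shuffled = new ArrayList<>(unique.values());  // copy: input may be immutable
        Collections.shuffle(shuffled, new Random(seed));            // fixed seed -> same split every run

        int validationSize = (int) (shuffled.size() * validationRatio);
        return new Split(
            List.copyOf(shuffled.subList(validationSize, shuffled.size())),
            List.copyOf(shuffled.subList(0, validationSize)));
    }
}
```

Run augmentation after this step, on the training list only, so paraphrases never leak into validation.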

Example 3: LoRA Fine-Tuning (Python/Conceptual)

/**
 * Builds a LoRA/QLoRA training configuration and renders it as a Python script.
 * The training itself runs in Python with Hugging Face transformers, peft, and trl.
 */
@Service
public class LoraConfigService {
    
    /**
     * Generate LoRA training configuration
     */
    public LoraTrainingConfig generateConfig(LoraRequest request) {
        return LoraTrainingConfig.builder()
            // LoRA specific parameters
            .loraR(16)                        // Rank of update matrices
            .loraAlpha(32)                    // Scaling factor
            .loraDropout(0.1)                 // Dropout for regularization
            .targetModules(List.of(           // Which layers to adapt
                "q_proj", "k_proj", "v_proj", "o_proj",  // Attention
                "gate_proj", "up_proj", "down_proj"       // MLP
            ))
            
            // Training parameters
            .baseModel(request.getBaseModel())
            .learningRate(2e-4)
            .batchSize(4)
            .gradientAccumulationSteps(4)     // Effective batch = 16
            .epochs(3)
            .warmupRatio(0.03)
            .weightDecay(0.01)
            
            // Quantization for QLoRA
            .quantization(request.isUseQlora() ? "4bit" : null)
            .bnbConfig(request.isUseQlora() ? BnbConfig.builder()
                .loadIn4bit(true)
                .bnb4bitQuant("nf4")
                .bnb4bitComputeDtype("bfloat16")
                .bnb4bitUseDoubleQuant(true)
                .build() : null)
            
            // Output
            .outputDir(request.getOutputPath())
            .loggingSteps(10)
            .saveSteps(500)
            .evalSteps(500)
            
            .build();
    }
    
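    /**
     * Generate Python training script.
     * The SFTTrainer call uses the older TRL signature (TrainingArguments plus tokenizer=
     * and max_seq_length=); newer TRL releases move these options into SFTConfig.
     */
    public String generateTrainingScript(LoraTrainingConfig config) {
        // Render target modules as a Python list literal, e.g. ["q_proj", "k_proj"];
        // List.toString() would produce [q_proj, k_proj], which is not valid Python
        String targetModules = config.getTargetModules().stream()
            .map(m -> "\"" + m + "\"")
            .collect(Collectors.joining(", ", "[", "]"));
        
        // Locale.ROOT keeps "." as the decimal separator so %f and %e emit valid Python floats
        return String.format(Locale.ROOT, """
            from datasets import load_dataset
            from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
            from peft import LoraConfig, get_peft_model
            from trl import SFTTrainer
            import torch
            
            # Load chat-format JSONL data (file names are placeholders)
            dataset = load_dataset("json", data_files={"train": "train.jsonl", "validation": "val.jsonl"})
            train_dataset = dataset["train"]
            eval_dataset = dataset["validation"]
            
            # Load base model
            model = AutoModelForCausalLM.from_pretrained(
                "%s",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                %s
            )
            
            tokenizer = AutoTokenizer.from_pretrained("%s")
            tokenizer.pad_token = tokenizer.eos_token
            
            # Configure LoRA
            lora_config = LoraConfig(
                r=%d,
                lora_alpha=%d,
                lora_dropout=%f,
                target_modules=%s,
                task_type="CAUSAL_LM"
            )
            
            model = get_peft_model(model, lora_config)
            model.print_trainable_parameters()
            
            # Training arguments
            training_args = TrainingArguments(
                output_dir="%s",
                per_device_train_batch_size=%d,
                gradient_accumulation_steps=%d,
                learning_rate=%e,
                num_train_epochs=%d,
                warmup_ratio=%f,
                weight_decay=%f,
                logging_steps=%d,
                save_steps=%d,
                eval_strategy="steps",  # eval_steps is ignored without this (evaluation_strategy in older releases)
                eval_steps=%d,
                bf16=True,
            )
            
            # Create trainer
            trainer = SFTTrainer(
                model=model,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                args=training_args,
                tokenizer=tokenizer,
                max_seq_length=2048,
            )
            
            # Train
            trainer.train()
            
            # Save LoRA adapter
            model.save_pretrained("%s/lora_adapter")
            """,
                config.getBaseModel(),
                config.getBnbConfig() != null ? generateBnbConfig(config.getBnbConfig()) : "",
                config.getBaseModel(),
                config.getLoraR(),
                config.getLoraAlpha(),
                config.getLoraDropout(),
                targetModules,
                config.getOutputDir(),
                config.getBatchSize(),
                config.getGradientAccumulationSteps(),
                config.getLearningRate(),
                config.getEpochs(),
                config.getWarmupRatio(),
                config.getWeightDecay(),
                config.getLoggingSteps(),
                config.getSaveSteps(),
                config.getEvalSteps(),
                config.getOutputDir()
            );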
    }
}
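The batch and step settings in `generateConfig` interact with dataset size. The sketch below does the arithmetic: effective batch size, approximate optimizer steps, and warmup steps derived from `warmupRatio`. The helper names are hypothetical, and the ceiling-based step count is an approximation; the exact count in a given trainer depends on how it handles the last partial batch.

```java
/** Rough training-schedule arithmetic for a LoRA config (illustrative sketch). */
public class TrainingSchedule {

    /** Examples consumed per optimizer step. */
    static int effectiveBatchSize(int perDeviceBatch, int gradAccumSteps, int numDevices) {
        return perDeviceBatch * gradAccumSteps * numDevices;
    }

    /** Approximate optimizer steps for the whole run (a last partial batch counts as a step). */
    static int totalSteps(int numExamples, int effectiveBatch, int epochs) {
        int stepsPerEpoch = (numExamples + effectiveBatch - 1) / effectiveBatch;  // ceiling division
        return stepsPerEpoch * epochs;
    }

    /** Warmup steps derived from a warmup ratio, rounded up. */
    static int warmupSteps(int totalSteps, double warmupRatio) {
        return (int) Math.ceil(totalSteps * warmupRatio);
    }
}
```

For example, a hypothetical 2,000 examples with batch size 4, 4 accumulation steps, one GPU, and 3 epochs give an effective batch of 16, 375 optimizer steps, and 12 warmup steps at a 0.03 ratio. That run never reaches step 500, so `saveSteps(500)` and `evalSteps(500)` would produce no mid-run checkpoint or evaluation; scale them to the expected step count.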

Example 4: Model Evaluation

@Service
public class FineTunedModelEvaluator {
    
    private final LlmClient llmClient;
    
    /**
     * Evaluate fine-tuned model against test set
     */
    public EvaluationReport evaluate(String modelId, List<TestCase> testCases) {
        List<EvaluationResult> results = new ArrayList<>();
        
        for (TestCase testCase : testCases) {
            // Get model response
            CompletionResponse response = llmClient.complete(
                CompletionRequest.builder()
                    .model(modelId)
                    .messages(List.of(
                        Message.system(testCase.getSystemPrompt()),
                        Message.user(testCase.getInput())
                    ))
                    .temperature(0.0)  // Minimizes sampling variance; hosted models may still vary slightly between runs
                    .build()
            );
            
            String generated = response.getContent();
            String expected = testCase.getExpectedOutput();
            
            // Calculate metrics
            double exactMatch = generated.trim().equals(expected.trim()) ? 1.0 : 0.0;
            double similarity = calculateSimilarity(generated, expected);
            double formatCompliance = checkFormatCompliance(generated, testCase.getExpectedFormat());
            
            results.add(EvaluationResult.builder()
                .testCaseId(testCase.getId())
                .input(testCase.getInput())
                .expected(expected)
                .generated(generated)
                .exactMatch(exactMatch)
                .similarity(similarity)
                .formatCompliance(formatCompliance)
                .build());
        }
        
        return EvaluationReport.builder()
            .modelId(modelId)
            .totalTestCases(testCases.size())
            .avgExactMatch(average(results, EvaluationResult::getExactMatch))
            .avgSimilarity(average(results, EvaluationResult::getSimilarity))
            .avgFormatCompliance(average(results, EvaluationResult::getFormatCompliance))
            .results(results)
            .build();
    }
    
    /**
     * Compare fine-tuned model with base model
     */
    public ComparisonReport compareWithBase(
            String fineTunedModel, 
            String baseModel,
            List<TestCase> testCases) {
        
        EvaluationReport fineTunedResults = evaluate(fineTunedModel, testCases);
        EvaluationReport baseResults = evaluate(baseModel, testCases);
        
        return ComparisonReport.builder()
            .fineTunedModel(fineTunedModel)
            .baseModel(baseModel)
            .fineTunedMetrics(fineTunedResults)
            .baseMetrics(baseResults)
            .exactMatchImprovement(
                fineTunedResults.getAvgExactMatch() - baseResults.getAvgExactMatch())
            .similarityImprovement(
                fineTunedResults.getAvgSimilarity() - baseResults.getAvgSimilarity())
            .build();
    }
    
    /**
     * Test for regression on general capabilities
     */
    public RegressionReport checkRegression(
            String fineTunedModel,
            String baseModel,
            List<GeneralCapabilityTest> generalTests) {
        
        List<RegressionResult> results = new ArrayList<>();
        
        for (GeneralCapabilityTest test : generalTests) {
            // Test both models
            CompletionResponse fineTunedResponse = runTest(fineTunedModel, test);
            CompletionResponse baseResponse = runTest(baseModel, test);
            
            // Evaluate responses
            double fineTunedScore = evaluateGeneral(fineTunedResponse, test);
            double baseScore = evaluateGeneral(baseResponse, test);
            
            results.add(new RegressionResult(
                test.getCapability(),
                baseScore,
                fineTunedScore,
                fineTunedScore >= baseScore * 0.95  // Allow 5% degradation
            ));
        }
        
        boolean hasRegression = results.stream().anyMatch(r -> !r.passed());
        
        return new RegressionReport(results, !hasRegression);
    }
}
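`calculateSimilarity` is left undefined above. One cheap, common choice is token-level F1, the overlap metric used in SQuAD-style question-answering evaluation. This is a simplified sketch (lowercasing and punctuation stripping only, no article removal or stemming), and the `TokenF1` class is hypothetical; embedding similarity or an LLM judge may suit open-ended outputs better.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/** Token-level F1 between a generated and an expected answer (simplified sketch). */
public class TokenF1 {

    static String[] tokenize(String text) {
        // Lowercase, replace anything that is not a letter, digit, or whitespace with a space
        String cleaned = text.toLowerCase(Locale.ROOT).replaceAll("[^\\p{L}\\p{N}\\s]", " ").trim();
        return cleaned.isEmpty() ? new String[0] : cleaned.split("\\s+");
    }

    static double f1(String generated, String expected) {
        String[] pred = tokenize(generated);
        String[] gold = tokenize(expected);
        if (pred.length == 0 || gold.length == 0) {
            return pred.length == gold.length ? 1.0 : 0.0;  // both empty counts as a match
        }
        // Overlap is the size of the multiset intersection of tokens
        Map<String, Integer> goldCounts = new HashMap<>();
        for (String t : gold) {
            goldCounts.merge(t, 1, Integer::sum);
        }
        int overlap = 0;
        for (String t : pred) {
            Integer c = goldCounts.get(t);
            if (c != null && c > 0) {
                overlap++;
                goldCounts.put(t, c - 1);
            }
        }
        if (overlap == 0) {
            return 0.0;
        }
        double precision = (double) overlap / pred.length;
        double recall = (double) overlap / gold.length;
        return 2 * precision * recall / (precision + recall);
    }
}
```

Unlike exact match, this gives partial credit: "cat sat on mat" against "the cat sat" shares 2 tokens, so precision is 1/2, recall is 2/3, and F1 is 4/7.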

Example 5: Model Versioning

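@Service
@RequiredArgsConstructor
public class ModelVersioningService {
    
    private final ModelRepository modelRepository;
    private final ArtifactStorage artifactStorage;
    private final ModelRouter modelRouter;            // routes inference traffic per model name
    private final ABTestRepository abTestRepository;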
    
    /**
     * Register a new fine-tuned model version
     */
    public ModelVersion registerModel(RegisterModelRequest request) {
        ModelVersion version = ModelVersion.builder()
            .id(UUID.randomUUID().toString())
            .name(request.getModelName())
            .version(calculateNextVersion(request.getModelName()))
            .baseModel(request.getBaseModel())
            .fineTunedModelId(request.getFineTunedModelId())
            .trainingConfig(request.getTrainingConfig())
            .trainingMetrics(request.getTrainingMetrics())
            .evaluationReport(request.getEvaluationReport())
            .datasetVersion(request.getDatasetVersion())
            .createdAt(Instant.now())
            .createdBy(request.getCreatedBy())
            .status(ModelStatus.STAGING)
            .build();
        
        // Store training artifacts
        if (request.getLoraAdapter() != null) {
            String adapterPath = artifactStorage.store(
                request.getLoraAdapter(),
                "models/" + version.getId() + "/adapter"
            );
            version.setAdapterPath(adapterPath);
        }
        
        modelRepository.save(version);
        
        return version;
    }
    
    /**
     * Promote model to production. The demotion and the promotion run in one
     * transaction so they are committed together.
     */
    @Transactional
    public void promoteToProduction(String modelId, String reason) {
        ModelVersion model = modelRepository.findById(modelId)
            .orElseThrow(() -> new NotFoundException("Model not found"));
        
        // Demote current production model (skip if it is the one being promoted)
        modelRepository.findByNameAndStatus(model.getName(), ModelStatus.PRODUCTION)
            .filter(current -> !current.getId().equals(model.getId()))
            .ifPresent(current -> {
                current.setStatus(ModelStatus.ARCHIVED);
                modelRepository.save(current);
            });
        
        // Promote new model
        model.setStatus(ModelStatus.PRODUCTION);
        model.setPromotedAt(Instant.now());
        model.setPromotionReason(reason);
        modelRepository.save(model);
        
        // Update model router
        modelRouter.setActiveModel(model.getName(), model.getFineTunedModelId());
    }
    
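    /**
     * Create and persist an A/B test between two models
     */
    public ABTestConfig createABTest(ABTestRequest request) {
        double split = request.getTrafficSplit();  // e.g., 0.1 = 10% to treatment
        if (split < 0.0 || split > 1.0) {
            throw new IllegalArgumentException("trafficSplit must be between 0 and 1, got " + split);
        }
        
        ABTestConfig config = ABTestConfig.builder()
            .id(UUID.randomUUID().toString())
            .name(request.getName())
            .controlModel(request.getControlModelId())
            .treatmentModel(request.getTreatmentModelId())
            .trafficSplit(split)
            .metrics(request.getMetrics())
            .startDate(request.getStartDate())
            .endDate(request.getEndDate())
            .status(ABTestStatus.RUNNING)
            .build();
        
        // Persist so getModelForRequest() can find the running test
        return abTestRepository.save(config);
    }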
    
    /**
     * Get model for inference with A/B routing
     */
    public String getModelForRequest(String modelName, String requestId) {
        // Check for active A/B test
        Optional<ABTestConfig> activeTest = abTestRepository
            .findActiveByModelName(modelName);
        
        if (activeTest.isPresent()) {
            ABTestConfig test = activeTest.get();
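            // Deterministic routing: the same request ID always maps to the same bucket.
            // Math.floorMod stays non-negative, unlike Math.abs(Integer.MIN_VALUE).
            // Routing on a user or session ID keeps each user on one model across requests.
            double hash = Math.floorMod(requestId.hashCode(), 10_000) / 10_000.0;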
            
            String selectedModel = hash < test.getTrafficSplit() 
                ? test.getTreatmentModel() 
                : test.getControlModel();
            
            // Log for analysis
            logABTestAssignment(test.getId(), requestId, selectedModel);
            
            return selectedModel;
        }
        
        // Return production model
        return modelRepository.findByNameAndStatus(modelName, ModelStatus.PRODUCTION)
            .map(ModelVersion::getFineTunedModelId)
            .orElseThrow(() -> new NotFoundException("No production model for: " + modelName));
    }
}
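`getModelForRequest` buckets on the request ID alone, so assignments are correlated across tests: a request that lands in treatment for one test lands in treatment for every test with the same traffic split. A common refinement is to hash the test ID together with the unit ID so each test gets an independent assignment. The sketch below does this with SHA-256; `AbBucketing` is a hypothetical helper, and `unitId` would typically be a user or session ID.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** Deterministic A/B bucketing salted with the test ID (illustrative sketch). */
public class AbBucketing {

    /** Maps (testId, unitId) to a stable value in [0, 1). */
    static double bucket(String testId, String unitId) {
        try {
            MessageDigest sha = MessageDigest.getInstance("SHA-256");
            byte[] digest = sha.digest((testId + ":" + unitId).getBytes(StandardCharsets.UTF_8));
            // Read the first 4 bytes as an unsigned 32-bit integer
            long value = ((digest[0] & 0xFFL) << 24) | ((digest[1] & 0xFFL) << 16)
                       | ((digest[2] & 0xFFL) << 8) | (digest[3] & 0xFFL);
            return value / (double) (1L << 32);
        } catch (NoSuchAlgorithmException e) {
            // Every Java SE platform is required to provide SHA-256
            throw new IllegalStateException("SHA-256 unavailable", e);
        }
    }

    /** True if the unit falls in the treatment share of the given test. */
    static boolean inTreatment(String testId, String unitId, double trafficSplit) {
        return bucket(testId, unitId) < trafficSplit;
    }
}
```

Because the hash depends only on its inputs, a unit keeps its assignment across restarts and instances, which makes the logged assignments straightforward to join with outcome metrics.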

Anti-Patterns

❌ Fine-Tuning Without Evaluation

Always measure performance on the same test set before and after fine-tuning.

❌ Overfitting on Small Dataset

Add more diverse data rather than more epochs; extra epochs on a small dataset mostly teach the model to memorize it.

