Fine-Tuning
Overview
Fine-tuning adapts pre-trained models to a specific domain or task. This skill covers when to fine-tune versus prompt, parameter-efficient methods such as LoRA and QLoRA, and training best practices.
Key Concepts
Fine-Tuning Decision Tree
┌─────────────────────────────────────────────────────────────┐
│ When to Fine-Tune vs Prompt │
├─────────────────────────────────────────────────────────────┤
│ │
│ Start Here: Can prompting solve your problem? │
│ │ │
│ ├─── YES → Use prompting (cheaper, faster) │
│ │ │
│ └─── NO → Consider fine-tuning if: │
│ │ │
│ ├── Consistent output format needed │
│ ├── Domain-specific terminology/style │
│ ├── Need to reduce token usage │
│ ├── Latency requirements │
│ └── Proprietary knowledge integration │
│ │
│ Fine-Tuning Methods: │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Full Fine-Tuning │ │
│ │ ├── Updates all parameters │ │
│ │ ├── Highest quality, most expensive │ │
│ │ └── Risk of catastrophic forgetting │ │
│ │ │ │
│ │ LoRA (Low-Rank Adaptation) │ │
│ │ ├── Trains small adapter matrices │ │
│ │ ├── <1% of parameters updated │ │
│ │ └── Can swap adapters at inference │ │
│ │ │ │
│ │ QLoRA │ │
│ │ ├── LoRA with 4-bit quantization │ │
│ │ └── Enables fine-tuning on consumer hardware │ │
│ │ │ │
│ │ Prefix Tuning │ │
│ │ ├── Adds trainable tokens to input │ │
│ │ └── Very parameter efficient │ │
│ │ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
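To make the "<1% of parameters updated" figure concrete: LoRA replaces an update to a full d_out × d_in weight matrix with two trainable low-rank factors B (d_out × r) and A (r × d_in). A quick calculation for an illustrative 4096×4096 attention projection at rank 16 (the dimensions are assumptions for illustration, not tied to any particular model):

```java
// Why LoRA is parameter-efficient: count trainable parameters for one
// weight matrix under full fine-tuning vs. a rank-r LoRA update.
public class LoraParamCount {
    public static void main(String[] args) {
        int dIn = 4096, dOut = 4096, rank = 16; // illustrative dimensions

        long fullParams = (long) dOut * dIn;          // full fine-tuning: every weight
        long loraParams = (long) rank * (dIn + dOut); // LoRA: B (dOut x r) + A (r x dIn)

        System.out.printf("Full: %d  LoRA: %d  ratio: %.2f%%%n",
            fullParams, loraParams, 100.0 * loraParams / fullParams);
        // prints: Full: 16777216  LoRA: 131072  ratio: 0.78%
    }
}
```

The same ratio holds per adapted layer, which is why LoRA checkpoints are small enough to swap at inference time.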
Best Practices
1. Start with High-Quality Data
Data quality matters more than quantity.
2. Use Validation Set
Monitor for overfitting.
3. Try LoRA First
Much cheaper and often effective.
4. Evaluate Thoroughly
Test on held-out data and edge cases.
5. Version Models
Track datasets, hyperparameters, and metrics.
Code Examples
Example 1: OpenAI Fine-Tuning
@Service
public class OpenAiFineTuningService {
private final WebClient webClient;
private final ObjectMapper objectMapper;
private final LlmClient llmClient; // used by complete() below
/**
* Prepare training data in JSONL format
*/
public File prepareTrainingData(List<TrainingExample> examples, String outputPath)
throws IOException {
File file = new File(outputPath);
try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
for (TrainingExample example : examples) {
Map<String, Object> jsonlRecord = Map.of(
"messages", List.of(
Map.of("role", "system", "content", example.getSystemPrompt()),
Map.of("role", "user", "content", example.getUserMessage()),
Map.of("role", "assistant", "content", example.getAssistantResponse())
)
);
writer.write(objectMapper.writeValueAsString(jsonlRecord));
writer.newLine();
}
}
return file;
}
/**
* Upload training file
*/
public String uploadTrainingFile(File file) {
MultipartBodyBuilder builder = new MultipartBodyBuilder();
builder.part("file", new FileSystemResource(file));
builder.part("purpose", "fine-tune");
FileUploadResponse response = webClient.post()
.uri("/v1/files")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromMultipartData(builder.build()))
.retrieve()
.bodyToMono(FileUploadResponse.class)
.block();
return response.getId();
}
/**
* Create fine-tuning job
*/
public FineTuningJob createFineTuningJob(FineTuningConfig config) {
Map<String, Object> request = new HashMap<>();
request.put("training_file", config.getTrainingFileId());
request.put("model", config.getBaseModel()); // e.g., "gpt-4o-mini-2024-07-18"
// Optional hyperparameters
if (config.getEpochs() != null || config.getLearningRateMultiplier() != null) {
Map<String, Object> hyperparams = new HashMap<>();
if (config.getEpochs() != null) {
hyperparams.put("n_epochs", config.getEpochs());
}
if (config.getLearningRateMultiplier() != null) {
hyperparams.put("learning_rate_multiplier", config.getLearningRateMultiplier());
}
request.put("hyperparameters", hyperparams);
}
if (config.getValidationFileId() != null) {
request.put("validation_file", config.getValidationFileId());
}
if (config.getSuffix() != null) {
request.put("suffix", config.getSuffix()); // Custom model name suffix
}
return webClient.post()
.uri("/v1/fine_tuning/jobs")
.bodyValue(request)
.retrieve()
.bodyToMono(FineTuningJob.class)
.block();
}
/**
* Monitor fine-tuning job
*/
public FineTuningJob getJobStatus(String jobId) {
return webClient.get()
.uri("/v1/fine_tuning/jobs/{id}", jobId)
.retrieve()
.bodyToMono(FineTuningJob.class)
.block();
}
/**
* List job events (training progress)
*/
public List<FineTuningEvent> getJobEvents(String jobId) {
FineTuningEventsResponse response = webClient.get()
.uri("/v1/fine_tuning/jobs/{id}/events", jobId)
.retrieve()
.bodyToMono(FineTuningEventsResponse.class)
.block();
return response.getData();
}
/**
* Use fine-tuned model
*/
public CompletionResponse complete(String fineTunedModel, String prompt) {
return llmClient.complete(
CompletionRequest.builder()
.model(fineTunedModel) // e.g., "ft:gpt-4o-mini:company::abc123"
.messages(List.of(Message.user(prompt)))
.build()
);
}
}
@Data
@Builder
class FineTuningConfig {
private String trainingFileId;
private String validationFileId;
private String baseModel;
private String suffix;
private Integer epochs;
private Double learningRateMultiplier;
}
@Data
class FineTuningJob {
private String id;
private String status; // validating_files, queued, running, succeeded, failed
private String fineTunedModel;
private Instant createdAt;
private Instant finishedAt;
private Error error;
}
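createFineTuningJob returns immediately while training runs asynchronously, so callers typically poll getJobStatus until a terminal state. A minimal polling sketch (the status strings match those listed on FineTuningJob above; JobPoller and its parameters are illustrative, and production code would prefer a scheduler over Thread.sleep):

```java
import java.util.Set;
import java.util.function.Function;

public final class JobPoller {
    // Terminal statuses per the FineTuningJob status comment above
    private static final Set<String> TERMINAL = Set.of("succeeded", "failed", "cancelled");

    /** Polls statusLookup until a terminal status, or fails after maxAttempts. */
    public static String awaitCompletion(Function<String, String> statusLookup,
                                         String jobId, int maxAttempts, long sleepMillis)
            throws InterruptedException {
        for (int i = 0; i < maxAttempts; i++) {
            String status = statusLookup.apply(jobId);
            if (TERMINAL.contains(status)) {
                return status;
            }
            Thread.sleep(sleepMillis); // back off between polls
        }
        throw new IllegalStateException(
            "Job " + jobId + " not finished after " + maxAttempts + " polls");
    }
}
```

Usage would pass a lookup backed by the service, e.g. `id -> service.getJobStatus(id).getStatus()`.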
Example 2: Training Data Preparation
@Service
public class TrainingDataService {
private final ParaphraseService paraphraseService; // used by augmentData() below
/**
* Create training examples from historical data
*/
public List<TrainingExample> prepareExamples(List<ConversationLog> logs) {
return logs.stream()
.filter(this::isHighQuality)
.map(this::toTrainingExample)
.toList();
}
private boolean isHighQuality(ConversationLog log) {
// Filter criteria
return log.getUserRating() >= 4.0 && // User rated positively
log.getResponseTime() < 5000 && // Fast response
log.getAssistantMessage().length() > 50 && // Substantive response
!log.isEdited(); // Not manually corrected
}
private TrainingExample toTrainingExample(ConversationLog log) {
return TrainingExample.builder()
.systemPrompt(buildSystemPrompt(log.getContext()))
.userMessage(log.getUserMessage())
.assistantResponse(log.getAssistantMessage())
.build();
}
/**
* Split data for training/validation
*/
public DataSplit splitData(List<TrainingExample> examples, double validationRatio) {
// Shuffle a copy so the caller's list is not mutated
List<TrainingExample> shuffled = new ArrayList<>(examples);
Collections.shuffle(shuffled);
int validationSize = (int) (shuffled.size() * validationRatio);
return new DataSplit(
shuffled.subList(validationSize, shuffled.size()), // Training
shuffled.subList(0, validationSize) // Validation
);
}
/**
* Augment training data
*/
public List<TrainingExample> augmentData(List<TrainingExample> examples) {
List<TrainingExample> augmented = new ArrayList<>(examples);
for (TrainingExample example : examples) {
// Paraphrase user messages
String paraphrasedUser = paraphraseService.paraphrase(example.getUserMessage());
augmented.add(example.withUserMessage(paraphrasedUser));
// Add typos/variations for robustness
String withTypos = addRealisticVariations(example.getUserMessage());
augmented.add(example.withUserMessage(withTypos));
}
return augmented;
}
/**
* Validate training data format
*/
public ValidationResult validateTrainingData(List<TrainingExample> examples) {
List<String> errors = new ArrayList<>();
List<String> warnings = new ArrayList<>();
// Check minimum size
if (examples.size() < 10) {
errors.add("Need at least 10 examples, got " + examples.size());
}
// Check for duplicates
Set<String> seen = new HashSet<>();
for (TrainingExample ex : examples) {
String key = ex.getUserMessage() + "|||" + ex.getAssistantResponse();
if (!seen.add(key)) {
warnings.add("Duplicate example found");
}
}
// Check token counts
for (int i = 0; i < examples.size(); i++) {
TrainingExample ex = examples.get(i);
int tokens = countTokens(ex);
if (tokens > 16000) {
errors.add("Example " + i + " exceeds token limit: " + tokens);
}
}
// Check for data leakage
checkDataLeakage(examples, warnings);
return new ValidationResult(errors.isEmpty(), errors, warnings);
}
record DataSplit(List<TrainingExample> training, List<TrainingExample> validation) {}
}
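validateTrainingData calls a countTokens helper that is not shown. A rough stand-in using the common heuristic of roughly four characters per token for English text (TokenEstimator is hypothetical; use a real tokenizer, such as a tiktoken port, when exact counts matter):

```java
public final class TokenEstimator {
    // ~4 characters per token is a rough rule of thumb for English text
    private static final double CHARS_PER_TOKEN = 4.0;

    /** Rough token estimate for a single string; returns 0 for null/empty input. */
    public static int estimate(String text) {
        if (text == null || text.isEmpty()) {
            return 0;
        }
        return (int) Math.ceil(text.length() / CHARS_PER_TOKEN);
    }

    /** Sums estimates for the three message bodies; ignores per-message formatting overhead. */
    public static int estimate(String systemPrompt, String userMessage, String assistantResponse) {
        return estimate(systemPrompt) + estimate(userMessage) + estimate(assistantResponse);
    }
}
```

Because the heuristic undercounts for code and non-English text, validation thresholds based on it should leave generous headroom below the real limit.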
Example 3: LoRA Fine-Tuning (Python/Conceptual)
/**
* This example shows the configuration for LoRA fine-tuning.
* Actual training typically runs in Python with HuggingFace.
*/
@Service
public class LoraConfigService {
/**
* Generate LoRA training configuration
*/
public LoraTrainingConfig generateConfig(LoraRequest request) {
return LoraTrainingConfig.builder()
// LoRA specific parameters
.loraR(16) // Rank of update matrices
.loraAlpha(32) // Scaling factor
.loraDropout(0.1) // Dropout for regularization
.targetModules(List.of( // Which layers to adapt
"q_proj", "k_proj", "v_proj", "o_proj", // Attention
"gate_proj", "up_proj", "down_proj" // MLP
))
// Training parameters
.baseModel(request.getBaseModel())
.learningRate(2e-4)
.batchSize(4)
.gradientAccumulationSteps(4) // Effective batch = 16
.epochs(3)
.warmupRatio(0.03)
.weightDecay(0.01)
// Quantization for QLoRA
.quantization(request.isUseQlora() ? "4bit" : null)
.bnbConfig(request.isUseQlora() ? BnbConfig.builder()
.loadIn4bit(true)
.bnb4bitQuant("nf4")
.bnb4bitComputeDtype("bfloat16")
.bnb4bitUseDoubleQuant(true)
.build() : null)
// Output
.outputDir(request.getOutputPath())
.loggingSteps(10)
.saveSteps(500)
.evalSteps(500)
.build();
}
/**
* Generate Python training script
*/
public String generateTrainingScript(LoraTrainingConfig config) {
return """
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
import torch
# Load base model
model = AutoModelForCausalLM.from_pretrained(
"%s",
torch_dtype=torch.bfloat16,
device_map="auto",
%s
)
tokenizer = AutoTokenizer.from_pretrained("%s")
tokenizer.pad_token = tokenizer.eos_token
# Configure LoRA
lora_config = LoraConfig(
r=%d,
lora_alpha=%d,
lora_dropout=%f,
target_modules=%s,
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training arguments
training_args = TrainingArguments(
output_dir="%s",
per_device_train_batch_size=%d,
gradient_accumulation_steps=%d,
learning_rate=%e,
num_train_epochs=%d,
warmup_ratio=%f,
weight_decay=%f,
logging_steps=%d,
save_steps=%d,
eval_steps=%d,
bf16=True,
)
# Create trainer
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
tokenizer=tokenizer,
max_seq_length=2048,
)
# Train
trainer.train()
# Save LoRA adapter
model.save_pretrained("%s/lora_adapter")
""".formatted(
config.getBaseModel(),
config.getBnbConfig() != null ? generateBnbConfig(config.getBnbConfig()) : "",
config.getBaseModel(),
config.getLoraR(),
config.getLoraAlpha(),
config.getLoraDropout(),
config.getTargetModules().stream() // quote entries so the output is a valid Python list literal
.map(m -> "\"" + m + "\"")
.collect(Collectors.joining(", ", "[", "]")),
config.getOutputDir(),
config.getBatchSize(),
config.getGradientAccumulationSteps(),
config.getLearningRate(),
config.getEpochs(),
config.getWarmupRatio(),
config.getWeightDecay(),
config.getLoggingSteps(),
config.getSaveSteps(),
config.getEvalSteps(),
config.getOutputDir()
);
}
}
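The config above pairs a per-device batch of 4 with 4 gradient-accumulation steps. A quick sketch of how effective batch size and optimizer steps fall out of those numbers (the dataset size is an illustrative assumption):

```java
public class TrainingStepMath {
    public static void main(String[] args) {
        int perDeviceBatch = 4, gradAccumSteps = 4; // from the LoRA config above
        int datasetSize = 10_000, epochs = 3;       // illustrative values

        // Gradients accumulate over gradAccumSteps micro-batches before each update
        int effectiveBatch = perDeviceBatch * gradAccumSteps;
        int stepsPerEpoch = (int) Math.ceil((double) datasetSize / effectiveBatch);

        System.out.println("effective batch = " + effectiveBatch);         // 16
        System.out.println("optimizer steps = " + stepsPerEpoch * epochs); // 1875
    }
}
```

This matters when choosing logging/save/eval step intervals: a save_steps of 500 against 1,875 total steps yields only three mid-run checkpoints.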
Example 4: Model Evaluation
@Service
public class FineTunedModelEvaluator {
private final LlmClient llmClient;
/**
* Evaluate fine-tuned model against test set
*/
public EvaluationReport evaluate(String modelId, List<TestCase> testCases) {
List<EvaluationResult> results = new ArrayList<>();
for (TestCase testCase : testCases) {
// Get model response
CompletionResponse response = llmClient.complete(
CompletionRequest.builder()
.model(modelId)
.messages(List.of(
Message.system(testCase.getSystemPrompt()),
Message.user(testCase.getInput())
))
.temperature(0.0) // Deterministic for evaluation
.build()
);
String generated = response.getContent();
String expected = testCase.getExpectedOutput();
// Calculate metrics
double exactMatch = generated.trim().equals(expected.trim()) ? 1.0 : 0.0;
double similarity = calculateSimilarity(generated, expected);
double formatCompliance = checkFormatCompliance(generated, testCase.getExpectedFormat());
results.add(EvaluationResult.builder()
.testCaseId(testCase.getId())
.input(testCase.getInput())
.expected(expected)
.generated(generated)
.exactMatch(exactMatch)
.similarity(similarity)
.formatCompliance(formatCompliance)
.build());
}
return EvaluationReport.builder()
.modelId(modelId)
.totalTestCases(testCases.size())
.avgExactMatch(average(results, EvaluationResult::getExactMatch))
.avgSimilarity(average(results, EvaluationResult::getSimilarity))
.avgFormatCompliance(average(results, EvaluationResult::getFormatCompliance))
.results(results)
.build();
}
/**
* Compare fine-tuned model with base model
*/
public ComparisonReport compareWithBase(
String fineTunedModel,
String baseModel,
List<TestCase> testCases) {
EvaluationReport fineTunedResults = evaluate(fineTunedModel, testCases);
EvaluationReport baseResults = evaluate(baseModel, testCases);
return ComparisonReport.builder()
.fineTunedModel(fineTunedModel)
.baseModel(baseModel)
.fineTunedMetrics(fineTunedResults)
.baseMetrics(baseResults)
.exactMatchImprovement(
fineTunedResults.getAvgExactMatch() - baseResults.getAvgExactMatch())
.similarityImprovement(
fineTunedResults.getAvgSimilarity() - baseResults.getAvgSimilarity())
.build();
}
/**
* Test for regression on general capabilities
*/
public RegressionReport checkRegression(
String fineTunedModel,
String baseModel,
List<GeneralCapabilityTest> generalTests) {
List<RegressionResult> results = new ArrayList<>();
for (GeneralCapabilityTest test : generalTests) {
// Test both models
CompletionResponse fineTunedResponse = runTest(fineTunedModel, test);
CompletionResponse baseResponse = runTest(baseModel, test);
// Evaluate responses
double fineTunedScore = evaluateGeneral(fineTunedResponse, test);
double baseScore = evaluateGeneral(baseResponse, test);
results.add(new RegressionResult(
test.getCapability(),
baseScore,
fineTunedScore,
fineTunedScore >= baseScore * 0.95 // Allow 5% degradation
));
}
boolean hasRegression = results.stream().anyMatch(r -> !r.passed());
return new RegressionReport(results, !hasRegression);
}
}
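The evaluator calls calculateSimilarity without defining it. One simple stand-in is Jaccard overlap over lowercase word sets, sketched below; embedding cosine similarity or ROUGE would be stronger choices in practice:

```java
import java.util.HashSet;
import java.util.Set;

public final class TextSimilarity {
    /** Jaccard similarity over word sets: |A ∩ B| / |A ∪ B|, in [0, 1]. */
    public static double jaccard(String a, String b) {
        Set<String> tokensA = tokenize(a);
        Set<String> tokensB = tokenize(b);
        if (tokensA.isEmpty() && tokensB.isEmpty()) {
            return 1.0; // two empty texts count as identical
        }
        Set<String> intersection = new HashSet<>(tokensA);
        intersection.retainAll(tokensB);
        Set<String> union = new HashSet<>(tokensA);
        union.addAll(tokensB);
        return (double) intersection.size() / union.size();
    }

    private static Set<String> tokenize(String text) {
        Set<String> tokens = new HashSet<>();
        for (String t : text.toLowerCase().split("\\W+")) {
            if (!t.isEmpty()) {
                tokens.add(t);
            }
        }
        return tokens;
    }
}
```

Note that set-based overlap ignores word order, so it rewards content overlap rather than fluency; pair it with the exact-match and format-compliance metrics above.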
Example 5: Model Versioning
@Service
public class ModelVersioningService {
private final ModelRepository modelRepository;
private final ArtifactStorage artifactStorage;
private final ModelRouter modelRouter; // used by promoteToProduction()
private final ABTestRepository abTestRepository; // used by getModelForRequest()
/**
* Register a new fine-tuned model version
*/
public ModelVersion registerModel(RegisterModelRequest request) {
ModelVersion version = ModelVersion.builder()
.id(UUID.randomUUID().toString())
.name(request.getModelName())
.version(calculateNextVersion(request.getModelName()))
.baseModel(request.getBaseModel())
.fineTunedModelId(request.getFineTunedModelId())
.trainingConfig(request.getTrainingConfig())
.trainingMetrics(request.getTrainingMetrics())
.evaluationReport(request.getEvaluationReport())
.datasetVersion(request.getDatasetVersion())
.createdAt(Instant.now())
.createdBy(request.getCreatedBy())
.status(ModelStatus.STAGING)
.build();
// Store training artifacts
if (request.getLoraAdapter() != null) {
String adapterPath = artifactStorage.store(
request.getLoraAdapter(),
"models/" + version.getId() + "/adapter"
);
version.setAdapterPath(adapterPath);
}
modelRepository.save(version);
return version;
}
/**
* Promote model to production
*/
public void promoteToProduction(String modelId, String reason) {
ModelVersion model = modelRepository.findById(modelId)
.orElseThrow(() -> new NotFoundException("Model not found"));
// Demote current production model
modelRepository.findByNameAndStatus(model.getName(), ModelStatus.PRODUCTION)
.ifPresent(current -> {
current.setStatus(ModelStatus.ARCHIVED);
modelRepository.save(current);
});
// Promote new model
model.setStatus(ModelStatus.PRODUCTION);
model.setPromotedAt(Instant.now());
model.setPromotionReason(reason);
modelRepository.save(model);
// Update model router
modelRouter.setActiveModel(model.getName(), model.getFineTunedModelId());
}
/**
* A/B test models
*/
public ABTestConfig createABTest(ABTestRequest request) {
return ABTestConfig.builder()
.id(UUID.randomUUID().toString())
.name(request.getName())
.controlModel(request.getControlModelId())
.treatmentModel(request.getTreatmentModelId())
.trafficSplit(request.getTrafficSplit()) // e.g., 0.1 = 10% to treatment
.metrics(request.getMetrics())
.startDate(request.getStartDate())
.endDate(request.getEndDate())
.status(ABTestStatus.RUNNING)
.build();
}
/**
* Get model for inference with A/B routing
*/
public String getModelForRequest(String modelName, String requestId) {
// Check for active A/B test
Optional<ABTestConfig> activeTest = abTestRepository
.findActiveByModelName(modelName);
if (activeTest.isPresent()) {
ABTestConfig test = activeTest.get();
// Deterministic routing based on request ID
double hash = (requestId.hashCode() & 0x7fffffff) / (double) Integer.MAX_VALUE; // mask avoids Math.abs(Integer.MIN_VALUE) overflow
String selectedModel = hash < test.getTrafficSplit()
? test.getTreatmentModel()
: test.getControlModel();
// Log for analysis
logABTestAssignment(test.getId(), requestId, selectedModel);
return selectedModel;
}
// Return production model
return modelRepository.findByNameAndStatus(modelName, ModelStatus.PRODUCTION)
.map(ModelVersion::getFineTunedModelId)
.orElseThrow(() -> new NotFoundException("No production model for: " + modelName));
}
}
Anti-Patterns
❌ Fine-Tuning Without Evaluation
Always measure performance on held-out data before and after fine-tuning; without a baseline you cannot tell whether the model actually improved.
❌ Overfitting on a Small Dataset
Add data diversity rather than more epochs; extra passes over the same few examples teach the model to memorize, not generalize.