feat: 增强安全组件并添加内存保护机制

1. 重构 ReplayAttackInterceptor、RateLimitInterceptor、CaptchaUtil 和 LoginLockUtil,使用 LRUCache 和读写锁优化内存管理
2. 新增 MemoryProtector 类实现内存监控和保护机制
3. 为所有内存缓存组件添加容量限制和过期清理策略
4. 更新 .gitignore 文件配置
This commit is contained in:
ikmkj
2026-03-03 18:23:28 +08:00
parent 99e44e6c3b
commit 61aeba9c65
6 changed files with 488 additions and 100 deletions

1
.gitignore vendored
View File

@@ -5,6 +5,7 @@
.idea .idea
target target
.trae .trae
.trae/*
*.iml *.iml
.roo .roo
out out

View File

@@ -3,19 +3,22 @@ package com.test.bijihoudaun.interceptor;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.test.bijihoudaun.common.response.R; import com.test.bijihoudaun.common.response.R;
import com.test.bijihoudaun.common.response.ResultCode; import com.test.bijihoudaun.common.response.ResultCode;
import com.test.bijihoudaun.util.MemoryProtector;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.HandlerInterceptor;
import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.concurrent.atomic.AtomicInteger;
/** /**
* 限流拦截器 - 支持按 IP 和按用户双重限流 * 限流拦截器 - 支持按 IP 和按用户双重限流,带容量限制
*/ */
@Slf4j
public class RateLimitInterceptor implements HandlerInterceptor { public class RateLimitInterceptor implements HandlerInterceptor {
// 普通接口限流配置 // 普通接口限流配置
@@ -26,40 +29,70 @@ public class RateLimitInterceptor implements HandlerInterceptor {
private static final int MAX_LOGIN_REQUESTS_PER_MINUTE_USER = 10; private static final int MAX_LOGIN_REQUESTS_PER_MINUTE_USER = 10;
// 时间窗口(毫秒) // 时间窗口(毫秒)
private static final long WINDOW_SIZE_MS = 60_000; private static final long WINDOW_SIZE_MS = 60_000;
// 最大存储记录数(防止内存溢出)
private static final int MAX_RECORDS = 50000;
// IP 级别限流 // IP 级别限流
private final Map<String, RequestCounter> ipCounters = new ConcurrentHashMap<>(); private static final LRUCache<String, RequestCounter> ipCounters = new LRUCache<>(MAX_RECORDS / 2);
private final Map<String, RequestCounter> ipLoginCounters = new ConcurrentHashMap<>(); private static final LRUCache<String, RequestCounter> ipLoginCounters = new LRUCache<>(MAX_RECORDS / 4);
// 用户级别限流 // 用户级别限流
private final Map<String, RequestCounter> userCounters = new ConcurrentHashMap<>(); private static final LRUCache<String, RequestCounter> userCounters = new LRUCache<>(MAX_RECORDS / 4);
private final Map<String, RequestCounter> userLoginCounters = new ConcurrentHashMap<>(); private static final LRUCache<String, RequestCounter> userLoginCounters = new LRUCache<>(MAX_RECORDS / 4);
private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
private static class RequestCounter { private static class RequestCounter {
AtomicInteger count; int count;
long windowStart; long windowStart;
RequestCounter() { RequestCounter() {
this.count = new AtomicInteger(1); this.count = 1;
this.windowStart = System.currentTimeMillis(); this.windowStart = System.currentTimeMillis();
} }
boolean incrementAndCheck(int maxRequests) { boolean incrementAndCheck(int maxRequests) {
long now = System.currentTimeMillis(); long now = System.currentTimeMillis();
if (now - windowStart > WINDOW_SIZE_MS) { if (now - windowStart > WINDOW_SIZE_MS) {
synchronized (this) { // 新窗口
if (now - windowStart > WINDOW_SIZE_MS) { count = 1;
count.set(1); windowStart = now;
windowStart = now; return true;
return true;
}
}
} }
return count.incrementAndGet() <= maxRequests; count++;
return count <= maxRequests;
}
}
/**
* 简单的 LRU 缓存实现
*/
private static class LRUCache<K, V> extends LinkedHashMap<K, V> {
private final int maxSize;
LRUCache(int maxSize) {
super(maxSize, 0.75f, true);
this.maxSize = maxSize;
}
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
boolean shouldRemove = size() > maxSize;
if (shouldRemove) {
log.debug("限流记录达到上限,移除最旧的记录");
}
return shouldRemove;
} }
} }
@Override @Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
// 检查内存状态
if (MemoryProtector.isMemoryInsufficient()) {
log.warn("内存不足,拒绝请求: {}", request.getRequestURI());
MemoryProtector.writeMemoryInsufficientResponse(response);
return false;
}
String clientIp = getClientIp(request); String clientIp = getClientIp(request);
String requestUri = request.getRequestURI(); String requestUri = request.getRequestURI();
String username = getCurrentUsername(); String username = getCurrentUsername();
@@ -84,14 +117,19 @@ public class RateLimitInterceptor implements HandlerInterceptor {
*/ */
private boolean checkIpLimit(String clientIp, boolean isLoginRequest, private boolean checkIpLimit(String clientIp, boolean isLoginRequest,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
Map<String, RequestCounter> counters = isLoginRequest ? ipLoginCounters : ipCounters; LRUCache<String, RequestCounter> counters = isLoginRequest ? ipLoginCounters : ipCounters;
int maxRequests = isLoginRequest ? MAX_LOGIN_REQUESTS_PER_MINUTE : MAX_REQUESTS_PER_MINUTE; int maxRequests = isLoginRequest ? MAX_LOGIN_REQUESTS_PER_MINUTE : MAX_REQUESTS_PER_MINUTE;
RequestCounter counter = counters.computeIfAbsent(clientIp, k -> new RequestCounter()); lock.writeLock().lock();
try {
RequestCounter counter = counters.computeIfAbsent(clientIp, k -> new RequestCounter());
if (!counter.incrementAndCheck(maxRequests)) { if (!counter.incrementAndCheck(maxRequests)) {
writeRateLimitResponse(response, "请求过于频繁,请稍后再试"); writeRateLimitResponse(response, "请求过于频繁,请稍后再试");
return false; return false;
}
} finally {
lock.writeLock().unlock();
} }
return true; return true;
} }
@@ -101,14 +139,19 @@ public class RateLimitInterceptor implements HandlerInterceptor {
*/ */
private boolean checkUserLimit(String username, boolean isLoginRequest, private boolean checkUserLimit(String username, boolean isLoginRequest,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
Map<String, RequestCounter> counters = isLoginRequest ? userLoginCounters : userCounters; LRUCache<String, RequestCounter> counters = isLoginRequest ? userLoginCounters : userCounters;
int maxRequests = isLoginRequest ? MAX_LOGIN_REQUESTS_PER_MINUTE_USER : MAX_REQUESTS_PER_MINUTE_USER; int maxRequests = isLoginRequest ? MAX_LOGIN_REQUESTS_PER_MINUTE_USER : MAX_REQUESTS_PER_MINUTE_USER;
RequestCounter counter = counters.computeIfAbsent(username, k -> new RequestCounter()); lock.writeLock().lock();
try {
RequestCounter counter = counters.computeIfAbsent(username, k -> new RequestCounter());
if (!counter.incrementAndCheck(maxRequests)) { if (!counter.incrementAndCheck(maxRequests)) {
writeRateLimitResponse(response, "您的操作过于频繁,请稍后再试"); writeRateLimitResponse(response, "您的操作过于频繁,请稍后再试");
return false; return false;
}
} finally {
lock.writeLock().unlock();
} }
return true; return true;
} }

View File

@@ -3,26 +3,33 @@ package com.test.bijihoudaun.interceptor;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.test.bijihoudaun.common.response.R; import com.test.bijihoudaun.common.response.R;
import com.test.bijihoudaun.common.response.ResultCode; import com.test.bijihoudaun.common.response.ResultCode;
import com.test.bijihoudaun.util.MemoryProtector;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.HandlerInterceptor;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantReadWriteLock;
/** /**
* 防重放攻击拦截器 * 防重放攻击拦截器
* 通过时间戳和 nonce 机制防止请求被截获重放 * 通过时间戳和 nonce 机制防止请求被截获重放,带容量限制
*/ */
@Slf4j
public class ReplayAttackInterceptor implements HandlerInterceptor { public class ReplayAttackInterceptor implements HandlerInterceptor {
// 请求时间戳有效期毫秒5分钟 // 请求时间戳有效期毫秒5分钟
private static final long TIMESTAMP_VALIDITY = 5 * 60 * 1000; private static final long TIMESTAMP_VALIDITY = 5 * 60 * 1000;
// nonce 有效期毫秒10分钟 // nonce 有效期毫秒10分钟
private static final long NONCE_EXPIRE_TIME = 10 * 60 * 1000; private static final long NONCE_EXPIRE_TIME = 10 * 60 * 1000;
// 最大存储 nonce 数(防止内存溢出)
private static final int MAX_NONCES = 100000;
// 用于不需要验证的路径 // 用于不需要验证的路径
private static final Set<String> EXCLUDE_PATHS = new HashSet<>(Arrays.asList( private static final Set<String> EXCLUDE_PATHS = new HashSet<>(Arrays.asList(
"/api/user/login", "/api/user/login",
@@ -31,11 +38,40 @@ public class ReplayAttackInterceptor implements HandlerInterceptor {
)); ));
// 存储已使用的 noncekey=noncevalue=使用时间戳 // 存储已使用的 noncekey=noncevalue=使用时间戳
private static final Map<String, Long> usedNonces = new ConcurrentHashMap<>(); private static final LRUCache<String, Long> usedNonces = new LRUCache<>(MAX_NONCES);
private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
/**
* 简单的 LRU 缓存实现
*/
private static class LRUCache<K, V> extends LinkedHashMap<K, V> {
private final int maxSize;
LRUCache(int maxSize) {
super(maxSize, 0.75f, true);
this.maxSize = maxSize;
}
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
boolean shouldRemove = size() > maxSize;
if (shouldRemove) {
log.debug("nonce 存储达到上限,移除最旧的记录");
}
return shouldRemove;
}
}
@Override @Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
Object handler) throws Exception { Object handler) throws Exception {
// 检查内存状态
if (MemoryProtector.isMemoryInsufficient()) {
log.warn("内存不足,拒绝请求: {}", request.getRequestURI());
MemoryProtector.writeMemoryInsufficientResponse(response);
return false;
}
String requestUri = request.getRequestURI(); String requestUri = request.getRequestURI();
// 排除不需要验证的路径 // 排除不需要验证的路径
@@ -77,13 +113,23 @@ public class ReplayAttackInterceptor implements HandlerInterceptor {
} }
// 验证 nonce 是否已被使用 // 验证 nonce 是否已被使用
if (usedNonces.containsKey(nonce)) { lock.readLock().lock();
writeErrorResponse(response, "检测到重放攻击,请求已被拒绝"); try {
return false; if (usedNonces.containsKey(nonce)) {
writeErrorResponse(response, "检测到重放攻击,请求已被拒绝");
return false;
}
} finally {
lock.readLock().unlock();
} }
// 记录 nonce // 记录 nonce
usedNonces.put(nonce, now); lock.writeLock().lock();
try {
usedNonces.put(nonce, now);
} finally {
lock.writeLock().unlock();
}
return true; return true;
} }
@@ -92,10 +138,20 @@ public class ReplayAttackInterceptor implements HandlerInterceptor {
* 清理过期的 nonce * 清理过期的 nonce
*/ */
private void cleanupExpiredNonces() { private void cleanupExpiredNonces() {
long now = System.currentTimeMillis(); // 每80%容量时清理一次,减少开销
usedNonces.entrySet().removeIf(entry -> if (usedNonces.size() < MAX_NONCES * 0.8) {
(now - entry.getValue()) > NONCE_EXPIRE_TIME return;
); }
lock.writeLock().lock();
try {
long now = System.currentTimeMillis();
usedNonces.entrySet().removeIf(entry ->
(now - entry.getValue()) > NONCE_EXPIRE_TIME
);
} finally {
lock.writeLock().unlock();
}
} }
/** /**

View File

@@ -1,5 +1,7 @@
package com.test.bijihoudaun.util; package com.test.bijihoudaun.util;
import lombok.extern.slf4j.Slf4j;
import javax.imageio.ImageIO; import javax.imageio.ImageIO;
import java.awt.*; import java.awt.*;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
@@ -8,14 +10,15 @@ import java.io.IOException;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Base64; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantReadWriteLock;
/** /**
* 图形验证码工具类 * 图形验证码工具类
* 使用本地内存存储验证码 * 使用本地内存存储验证码,带容量限制
*/ */
@Slf4j
public class CaptchaUtil { public class CaptchaUtil {
// 验证码有效期(分钟) // 验证码有效期(分钟)
@@ -26,9 +29,12 @@ public class CaptchaUtil {
private static final int IMAGE_WIDTH = 120; private static final int IMAGE_WIDTH = 120;
// 图片高度 // 图片高度
private static final int IMAGE_HEIGHT = 40; private static final int IMAGE_HEIGHT = 40;
// 最大存储验证码数(防止内存溢出)
private static final int MAX_CAPTCHAS = 5000;
// 存储验证码key=验证码IDvalue=验证码记录 // 存储验证码key=验证码IDvalue=验证码记录
private static final Map<String, CaptchaRecord> captchaStore = new ConcurrentHashMap<>(); private static final LRUCache<String, CaptchaRecord> captchaStore = new LRUCache<>(MAX_CAPTCHAS);
private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
// 安全随机数生成器 // 安全随机数生成器
private static final SecureRandom random = new SecureRandom(); private static final SecureRandom random = new SecureRandom();
@@ -47,6 +53,27 @@ public class CaptchaUtil {
} }
} }
/**
* 简单的 LRU 缓存实现
*/
private static class LRUCache<K, V> extends LinkedHashMap<K, V> {
private final int maxSize;
LRUCache(int maxSize) {
super(maxSize, 0.75f, true);
this.maxSize = maxSize;
}
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
boolean shouldRemove = size() > maxSize;
if (shouldRemove) {
log.warn("验证码存储达到上限,移除最旧的验证码: {}", eldest.getKey());
}
return shouldRemove;
}
}
/** /**
* 验证码结果 * 验证码结果
*/ */
@@ -73,6 +100,12 @@ public class CaptchaUtil {
* @return 包含验证码ID和Base64图片的结果 * @return 包含验证码ID和Base64图片的结果
*/ */
public static CaptchaResult generateCaptcha() { public static CaptchaResult generateCaptcha() {
// 检查内存状态
if (MemoryProtector.isMemoryInsufficient()) {
log.error("内存不足,拒绝生成验证码");
throw new RuntimeException("服务器繁忙,请稍后再试");
}
// 清理过期验证码 // 清理过期验证码
cleanupExpiredCaptchas(); cleanupExpiredCaptchas();
@@ -85,8 +118,13 @@ public class CaptchaUtil {
// 生成图片 // 生成图片
String base64Image = generateImage(code); String base64Image = generateImage(code);
// 存储验证码 lock.writeLock().lock();
captchaStore.put(captchaId, new CaptchaRecord(code)); try {
// 存储验证码
captchaStore.put(captchaId, new CaptchaRecord(code));
} finally {
lock.writeLock().unlock();
}
return new CaptchaResult(captchaId, base64Image); return new CaptchaResult(captchaId, base64Image);
} }
@@ -102,24 +140,43 @@ public class CaptchaUtil {
return false; return false;
} }
CaptchaRecord record = captchaStore.get(captchaId); lock.readLock().lock();
if (record == null || record.isExpired()) { try {
// 验证码不存在或已过期,移除 CaptchaRecord record = captchaStore.get(captchaId);
if (record != null) { if (record == null || record.isExpired()) {
captchaStore.remove(captchaId); // 验证码不存在或已过期,移除
if (record != null) {
lock.readLock().unlock();
lock.writeLock().lock();
try {
captchaStore.remove(captchaId);
} finally {
lock.writeLock().unlock();
lock.readLock().lock();
}
}
return false;
} }
return false;
// 验证码比对(不区分大小写)
boolean success = record.code.equalsIgnoreCase(code);
// 验证成功后立即删除(一次性使用)
if (success) {
lock.readLock().unlock();
lock.writeLock().lock();
try {
captchaStore.remove(captchaId);
} finally {
lock.writeLock().unlock();
lock.readLock().lock();
}
}
return success;
} finally {
lock.readLock().unlock();
} }
// 验证码比对(不区分大小写)
boolean success = record.code.equalsIgnoreCase(code);
// 验证成功后立即删除(一次性使用)
if (success) {
captchaStore.remove(captchaId);
}
return success;
} }
/** /**
@@ -186,7 +243,7 @@ public class CaptchaUtil {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
ImageIO.write(image, "png", baos); ImageIO.write(image, "png", baos);
byte[] imageBytes = baos.toByteArray(); byte[] imageBytes = baos.toByteArray();
return "data:image/png;base64," + Base64.getEncoder().encodeToString(imageBytes); return "data:image/png;base64," + java.util.Base64.getEncoder().encodeToString(imageBytes);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException("生成验证码图片失败", e); throw new RuntimeException("生成验证码图片失败", e);
} }
@@ -196,6 +253,28 @@ public class CaptchaUtil {
* 清理过期验证码 * 清理过期验证码
*/ */
private static void cleanupExpiredCaptchas() { private static void cleanupExpiredCaptchas() {
captchaStore.entrySet().removeIf(entry -> entry.getValue().isExpired()); // 每80%容量时清理一次,减少开销
if (captchaStore.size() < MAX_CAPTCHAS * 0.8) {
return;
}
lock.writeLock().lock();
try {
captchaStore.entrySet().removeIf(entry -> entry.getValue().isExpired());
} finally {
lock.writeLock().unlock();
}
}
/**
* 获取当前验证码数量(用于监控)
*/
public static int getCaptchaCount() {
lock.readLock().lock();
try {
return captchaStore.size();
} finally {
lock.readLock().unlock();
}
} }
} }

View File

@@ -1,13 +1,18 @@
package com.test.bijihoudaun.util; package com.test.bijihoudaun.util;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.concurrent.ConcurrentHashMap; import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/** /**
* 登录锁定工具类 * 登录锁定工具类
* 使用本地内存存储登录失败记录 * 使用本地内存存储登录失败记录,带容量限制
*/ */
@Slf4j
public class LoginLockUtil { public class LoginLockUtil {
// 最大失败次数 // 最大失败次数
@@ -16,9 +21,12 @@ public class LoginLockUtil {
private static final int LOCK_TIME_MINUTES = 30; private static final int LOCK_TIME_MINUTES = 30;
// 失败记录过期时间(分钟) // 失败记录过期时间(分钟)
private static final int RECORD_EXPIRE_MINUTES = 60; private static final int RECORD_EXPIRE_MINUTES = 60;
// 最大存储记录数(防止内存溢出)
private static final int MAX_RECORDS = 10000;
// 登录失败记录key=用户名value=失败记录 // 登录失败记录key=用户名value=失败记录
private static final ConcurrentHashMap<String, LoginAttempt> attempts = new ConcurrentHashMap<>(); private static final LRUCache<String, LoginAttempt> attempts = new LRUCache<>(MAX_RECORDS);
private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
private static class LoginAttempt { private static class LoginAttempt {
int failedCount; int failedCount;
@@ -32,6 +40,27 @@ public class LoginLockUtil {
} }
} }
/**
* 简单的 LRU 缓存实现
*/
private static class LRUCache<K, V> extends LinkedHashMap<K, V> {
private final int maxSize;
LRUCache(int maxSize) {
super(maxSize, 0.75f, true);
this.maxSize = maxSize;
}
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
boolean shouldRemove = size() > maxSize;
if (shouldRemove) {
log.warn("登录锁定记录达到上限,移除最旧的记录: {}", eldest.getKey());
}
return shouldRemove;
}
}
/** /**
* 记录登录失败 * 记录登录失败
* @param username 用户名 * @param username 用户名
@@ -41,13 +70,19 @@ public class LoginLockUtil {
cleanupExpiredRecords(); cleanupExpiredRecords();
LoginAttempt attempt = attempts.computeIfAbsent(username, k -> new LoginAttempt()); lock.writeLock().lock();
attempt.failedCount++; try {
attempt.lastAttemptTime = LocalDateTime.now(); LoginAttempt attempt = attempts.computeIfAbsent(username, k -> new LoginAttempt());
attempt.failedCount++;
attempt.lastAttemptTime = LocalDateTime.now();
// 达到最大失败次数,锁定账号 // 达到最大失败次数,锁定账号
if (attempt.failedCount >= MAX_FAILED_ATTEMPTS) { if (attempt.failedCount >= MAX_FAILED_ATTEMPTS) {
attempt.lockUntil = LocalDateTime.now().plusMinutes(LOCK_TIME_MINUTES); attempt.lockUntil = LocalDateTime.now().plusMinutes(LOCK_TIME_MINUTES);
log.warn("账号 [{}] 已被锁定 {} 分钟", username, LOCK_TIME_MINUTES);
}
} finally {
lock.writeLock().unlock();
} }
} }
@@ -57,7 +92,13 @@ public class LoginLockUtil {
*/ */
public static void recordSuccess(String username) { public static void recordSuccess(String username) {
if (username == null || username.isEmpty()) return; if (username == null || username.isEmpty()) return;
attempts.remove(username);
lock.writeLock().lock();
try {
attempts.remove(username);
} finally {
lock.writeLock().unlock();
}
} }
/** /**
@@ -68,20 +109,32 @@ public class LoginLockUtil {
public static boolean isLocked(String username) { public static boolean isLocked(String username) {
if (username == null || username.isEmpty()) return false; if (username == null || username.isEmpty()) return false;
LoginAttempt attempt = attempts.get(username); lock.readLock().lock();
if (attempt == null) return false; try {
LoginAttempt attempt = attempts.get(username);
if (attempt == null) return false;
// 检查是否仍在锁定时间内 // 检查是否仍在锁定时间内
if (attempt.lockUntil != null) { if (attempt.lockUntil != null) {
if (LocalDateTime.now().isBefore(attempt.lockUntil)) { if (LocalDateTime.now().isBefore(attempt.lockUntil)) {
return true; return true;
} else { } else {
// 锁定时间已过,清除记录 // 锁定时间已过,清除记录
attempts.remove(username); lock.readLock().unlock();
return false; lock.writeLock().lock();
try {
attempts.remove(username);
} finally {
lock.writeLock().unlock();
lock.readLock().lock();
}
return false;
}
} }
return false;
} finally {
lock.readLock().unlock();
} }
return false;
} }
/** /**
@@ -92,11 +145,16 @@ public class LoginLockUtil {
public static long getRemainingLockTime(String username) { public static long getRemainingLockTime(String username) {
if (username == null || username.isEmpty()) return 0; if (username == null || username.isEmpty()) return 0;
LoginAttempt attempt = attempts.get(username); lock.readLock().lock();
if (attempt == null || attempt.lockUntil == null) return 0; try {
LoginAttempt attempt = attempts.get(username);
if (attempt == null || attempt.lockUntil == null) return 0;
long remaining = ChronoUnit.SECONDS.between(LocalDateTime.now(), attempt.lockUntil); long remaining = ChronoUnit.SECONDS.between(LocalDateTime.now(), attempt.lockUntil);
return Math.max(0, remaining); return Math.max(0, remaining);
} finally {
lock.readLock().unlock();
}
} }
/** /**
@@ -107,25 +165,52 @@ public class LoginLockUtil {
public static int getRemainingAttempts(String username) { public static int getRemainingAttempts(String username) {
if (username == null || username.isEmpty()) return MAX_FAILED_ATTEMPTS; if (username == null || username.isEmpty()) return MAX_FAILED_ATTEMPTS;
LoginAttempt attempt = attempts.get(username); lock.readLock().lock();
if (attempt == null) return MAX_FAILED_ATTEMPTS; try {
LoginAttempt attempt = attempts.get(username);
if (attempt == null) return MAX_FAILED_ATTEMPTS;
return Math.max(0, MAX_FAILED_ATTEMPTS - attempt.failedCount); return Math.max(0, MAX_FAILED_ATTEMPTS - attempt.failedCount);
} finally {
lock.readLock().unlock();
}
} }
/** /**
* 清理过期记录 * 清理过期记录
*/ */
private static void cleanupExpiredRecords() { private static void cleanupExpiredRecords() {
LocalDateTime now = LocalDateTime.now(); // 每100次操作清理一次减少开销
attempts.entrySet().removeIf(entry -> { if (attempts.size() < MAX_RECORDS * 0.8) {
LoginAttempt attempt = entry.getValue(); return;
// 未锁定且长时间没有登录的记录 }
if (attempt.lockUntil == null) {
return ChronoUnit.MINUTES.between(attempt.lastAttemptTime, now) > RECORD_EXPIRE_MINUTES; lock.writeLock().lock();
} try {
// 锁定已过期的记录 LocalDateTime now = LocalDateTime.now();
return now.isAfter(attempt.lockUntil); attempts.entrySet().removeIf(entry -> {
}); LoginAttempt attempt = entry.getValue();
// 未锁定且长时间没有登录的记录
if (attempt.lockUntil == null) {
return ChronoUnit.MINUTES.between(attempt.lastAttemptTime, now) > RECORD_EXPIRE_MINUTES;
}
// 锁定已过期的记录
return now.isAfter(attempt.lockUntil);
});
} finally {
lock.writeLock().unlock();
}
}
/**
* 获取当前记录数量(用于监控)
*/
public static int getRecordCount() {
lock.readLock().lock();
try {
return attempts.size();
} finally {
lock.readLock().unlock();
}
} }
} }

View File

@@ -0,0 +1,124 @@
package com.test.bijihoudaun.util;
import com.test.bijihoudaun.common.response.R;
import com.test.bijihoudaun.common.response.ResultCode;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* 内存保护工具类
* 监控 JVM 内存使用情况,当内存不足时拒绝请求
*/
@Slf4j
public class MemoryProtector {
// 内存使用率阈值(百分比),超过此值进入保护模式
private static final double MEMORY_THRESHOLD_PERCENT = 85.0;
// 堆内存使用阈值MB
private static final long HEAP_THRESHOLD_MB = 1500;
// 是否处于保护模式
private static final AtomicBoolean PROTECTION_MODE = new AtomicBoolean(false);
// 上次检查时间
private static volatile long lastCheckTime = 0;
// 检查间隔(毫秒)
private static final long CHECK_INTERVAL_MS = 5000;
/**
* 检查内存状态,如果内存不足返回 true
*/
public static boolean isMemoryInsufficient() {
long now = System.currentTimeMillis();
// 每 5 秒检查一次,减少性能开销
if (now - lastCheckTime < CHECK_INTERVAL_MS) {
return PROTECTION_MODE.get();
}
synchronized (MemoryProtector.class) {
if (now - lastCheckTime < CHECK_INTERVAL_MS) {
return PROTECTION_MODE.get();
}
lastCheckTime = now;
MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
MemoryUsage heapUsage = memoryMXBean.getHeapMemoryUsage();
long usedMB = heapUsage.getUsed() / 1024 / 1024;
long maxMB = heapUsage.getMax() / 1024 / 1024;
double usagePercent = maxMB > 0 ? (usedMB * 100.0 / maxMB) : 0;
boolean shouldProtect = usagePercent > MEMORY_THRESHOLD_PERCENT || usedMB > HEAP_THRESHOLD_MB;
if (shouldProtect && !PROTECTION_MODE.get()) {
log.warn("内存不足,进入保护模式 - 使用率: {}%, 已使用: {}MB, 最大: {}MB",
String.format("%.2f", usagePercent), usedMB, maxMB);
PROTECTION_MODE.set(true);
} else if (!shouldProtect && PROTECTION_MODE.get()) {
log.info("内存恢复正常,退出保护模式 - 使用率: {}%, 已使用: {}MB",
String.format("%.2f", usagePercent), usedMB);
PROTECTION_MODE.set(false);
}
return PROTECTION_MODE.get();
}
}
/**
* 写入内存不足响应
*/
public static void writeMemoryInsufficientResponse(HttpServletResponse response) throws IOException {
response.setContentType("application/json;charset=UTF-8");
response.setStatus(503);
response.getWriter().write(
"{\"code\":" + ResultCode.FAILED.getCode() +
",\"msg\":\"服务器繁忙,请稍后再试\",\"data\":null}"
);
}
/**
* 获取当前内存状态(用于监控)
*/
public static MemoryStatus getMemoryStatus() {
MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
MemoryUsage heapUsage = memoryMXBean.getHeapMemoryUsage();
MemoryStatus status = new MemoryStatus();
status.setUsedMB(heapUsage.getUsed() / 1024 / 1024);
status.setCommittedMB(heapUsage.getCommitted() / 1024 / 1024);
status.setMaxMB(heapUsage.getMax() / 1024 / 1024);
status.setUsagePercent(status.getMaxMB() > 0 ?
(status.getUsedMB() * 100.0 / status.getMaxMB()) : 0);
status.setProtectionMode(PROTECTION_MODE.get());
return status;
}
/**
* 内存状态信息
*/
public static class MemoryStatus {
private long usedMB;
private long committedMB;
private long maxMB;
private double usagePercent;
private boolean protectionMode;
public long getUsedMB() { return usedMB; }
public void setUsedMB(long usedMB) { this.usedMB = usedMB; }
public long getCommittedMB() { return committedMB; }
public void setCommittedMB(long committedMB) { this.committedMB = committedMB; }
public long getMaxMB() { return maxMB; }
public void setMaxMB(long maxMB) { this.maxMB = maxMB; }
public double getUsagePercent() { return usagePercent; }
public void setUsagePercent(double usagePercent) { this.usagePercent = usagePercent; }
public boolean isProtectionMode() { return protectionMode; }
public void setProtectionMode(boolean protectionMode) { this.protectionMode = protectionMode; }
}
}